import math
import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# %matplotlib inline

import torch
from torch import nn
from torch.nn import functional as F
# from torch.utils.data import TensorDataset, DataLoader
# from sklearn.metrics import mean_squared_error, mean_absolute_error


#################
##### MTGNN #####
#################
import numbers
from torch.nn import init

class nconv(nn.Module):
    """Static graph convolution: mixes node signals along a fixed adjacency."""

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        """Propagate x (N, C, V, L) over adjacency A (V, W) -> (N, C, W, L)."""
        mixed = torch.einsum('ncvl,vw->ncwl', (x, A))
        return mixed.contiguous()

class dy_nconv(nn.Module):
    """Dynamic graph convolution: adjacency varies per sample and per step."""

    def __init__(self):
        super(dy_nconv, self).__init__()

    def forward(self, x, A):
        """Propagate x (N, C, V, L) over A (N, V, W, L) -> (N, C, W, L)."""
        mixed = torch.einsum('ncvl,nvwl->ncwl', (x, A))
        return mixed.contiguous()

class linear(nn.Module):
    """Channel-wise linear map implemented as a 1x1 2-D convolution."""

    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1),
                                   padding=(0, 0), stride=(1, 1), bias=bias)

    def forward(self, x):
        """Map x (N, c_in, H, W) -> (N, c_out, H, W)."""
        return self.mlp(x)


class prop(nn.Module):
    """Plain graph propagation layer (MTGNN): gdep rounds of
    h <- alpha*x + (1-alpha)*A_norm*h, followed by a 1x1 linear projection."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(prop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear(c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout  # stored but not applied in forward
        self.alpha = alpha

    def forward(self, x, adj):
        # Add self-loops, then row-normalize the adjacency.
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        deg = adj.sum(1)
        a = adj / deg.view(-1, 1)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class mixprop(nn.Module):
    """Mix-hop propagation layer (MTGNN): concatenates the input with gdep
    propagated hops and projects them with a single 1x1 linear map."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout  # stored but not applied in forward
        self.alpha = alpha

    def forward(self, x, adj):
        # Self-loops + row normalization.
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        deg = adj.sum(1)
        a = adj / deg.view(-1, 1)
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            hops.append(h)
        return self.mlp(torch.cat(hops, dim=1))

class dy_mixprop(nn.Module):
    """Dynamic mix-hop propagation (MTGNN): derives a pair of input-dependent
    adjacencies, propagates along each direction, and sums the two projected
    results."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(dy_mixprop, self).__init__()
        self.nconv = dy_nconv()
        self.mlp1 = linear((gdep + 1) * c_in, c_out)
        self.mlp2 = linear((gdep + 1) * c_in, c_out)

        self.gdep = gdep
        self.dropout = dropout  # stored but not applied in forward
        self.alpha = alpha
        self.lin1 = linear(c_in, c_in)
        self.lin2 = linear(c_in, c_in)

    def _propagate(self, x, adj):
        # gdep rounds of h <- alpha*x + (1-alpha)*nconv(h, adj), keeping every hop.
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj)
            hops.append(h)
        return torch.cat(hops, dim=1)

    def forward(self, x):
        # Dynamic adjacency from two learned projections of the input.
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)

        ho1 = self.mlp1(self._propagate(x, adj0))
        ho2 = self.mlp2(self._propagate(x, adj1))
        return ho1 + ho2



class dilated_1D(nn.Module):
    """Single dilated temporal convolution (kernel width 7) along the last axis.

    Fix: the original assigned `self.tconv = nn.ModuleList()` and then
    immediately overwrote it with the Conv2d — the dead assignment is removed.
    """
    def __init__(self, cin, cout, dilation_factor=2):
        super(dilated_1D, self).__init__()
        # Kept for attribute parity with dilated_inception; unused here.
        self.kernel_set = [2,3,6,7]
        self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor))

    def forward(self,input):
        """Apply the convolution; the time axis shrinks by 6*dilation_factor."""
        x = self.tconv(input)
        return x

class dilated_inception(nn.Module):
    """Inception-style dilated temporal convolution: four parallel kernels
    (widths 2, 3, 6, 7), each producing cout/4 channels, cropped to a common
    length and concatenated along the channel axis."""

    def __init__(self, cin, cout, dilation_factor=2):
        super(dilated_inception, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = [2, 3, 6, 7]
        cout = int(cout / len(self.kernel_set))
        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin, cout, (1, kern), dilation=(1, dilation_factor)))

    def forward(self, input):
        branches = [conv(input) for conv in self.tconv]
        # The widest kernel yields the shortest output; crop every branch to it.
        min_len = branches[-1].size(3)
        branches = [b[..., -min_len:] for b in branches]
        return torch.cat(branches, dim=1)


class graph_constructor(nn.Module):
    """MTGNN graph-learning layer: learns a sparse directed adjacency.

    Each node gets two representations (learned embeddings, or projections of
    static features); the ReLU of the tanh-squashed antisymmetric inner
    product forms a dense adjacency, which forward() truncates to the top-k
    outgoing edges per node.
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)

        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def _node_vectors(self, idx):
        # Two per-node vectors: separate embeddings, or shared static features.
        if self.static_feat is None:
            base1 = self.emb1(idx)
            base2 = self.emb2(idx)
        else:
            base1 = self.static_feat[idx, :]
            base2 = base1
        vec1 = torch.tanh(self.alpha * self.lin1(base1))
        vec2 = torch.tanh(self.alpha * self.lin2(base2))
        return vec1, vec2

    def forward(self, idx):
        """Return the top-k-sparsified adjacency for the nodes in idx."""
        adj = self.fullA(idx)
        # Keep only the k strongest outgoing edges of each node.
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.fill_(1))
        return adj * mask

    def fullA(self, idx):
        """Return the dense (unsparsified) learned adjacency for idx."""
        vec1, vec2 = self._node_vectors(idx)
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        return F.relu(torch.tanh(self.alpha * a))

class graph_global(nn.Module):
    """Fully learnable global adjacency: one free parameter per node pair,
    passed through ReLU at lookup time. k/dim/alpha/static_feat are accepted
    only for signature parity with the other graph learners.

    Fix: the original wrapped the Parameter in a trailing `.to(device)` —
    a no-op here (the tensor is already on `device`), and one that would
    silently return a plain (unregistered) Tensor if it ever did move the
    data; removed.
    """
    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_global, self).__init__()
        self.nnodes = nnodes
        self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True)

    def forward(self, idx):
        """Return the non-negative global adjacency; idx is ignored."""
        return F.relu(self.A)


class graph_undirected(nn.Module):
    """Learns a symmetric (undirected) adjacency: one shared embedding and
    projection is used on both sides of the inner product, then each row is
    truncated to its top-k strongest edges."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_undirected, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)

        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        if self.static_feat is None:
            base1 = self.emb1(idx)
            base2 = self.emb1(idx)
        else:
            base1 = self.static_feat[idx, :]
            base2 = base1

        vec1 = torch.tanh(self.alpha * self.lin1(base1))
        vec2 = torch.tanh(self.alpha * self.lin1(base2))

        adj = F.relu(torch.tanh(self.alpha * torch.mm(vec1, vec2.transpose(1, 0))))
        # Sparsify: keep the k largest entries of each row.
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        scores, cols = adj.topk(self.k, 1)
        mask.scatter_(1, cols, scores.fill_(1))
        return adj * mask



class graph_directed(nn.Module):
    """Learns a directed adjacency: separate embeddings/projections on each
    side of the inner product (no antisymmetrization, unlike
    graph_constructor), rows truncated to the top-k strongest edges."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_directed, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)

        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        if self.static_feat is None:
            base1 = self.emb1(idx)
            base2 = self.emb2(idx)
        else:
            base1 = self.static_feat[idx, :]
            base2 = base1

        vec1 = torch.tanh(self.alpha * self.lin1(base1))
        vec2 = torch.tanh(self.alpha * self.lin2(base2))

        adj = F.relu(torch.tanh(self.alpha * torch.mm(vec1, vec2.transpose(1, 0))))
        # Sparsify: keep the k largest entries of each row.
        mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
        scores, cols = adj.topk(self.k, 1)
        mask.scatter_(1, cols, scores.fill_(1))
        return adj * mask


class LayerNorm(nn.Module):
    """LayerNorm variant whose affine parameters are indexed by node ids:
    forward(input, idx) normalizes over input.shape[1:] and applies the
    weight/bias slices selected along dimension 1 of the parameter tensors."""
    __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Identity affine transform: weight=1, bias=0.
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input, idx):
        shape = tuple(input.shape[1:])
        if not self.elementwise_affine:
            return F.layer_norm(input, shape, self.weight, self.bias, self.eps)
        # Select the per-node affine parameters along dimension 1.
        return F.layer_norm(input, shape, self.weight[:, idx, :], self.bias[:, idx, :], self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)


class MTGNN(nn.Module): # gtnet
    """MTGNN forecasting network (Wu et al., "Connecting the Dots", KDD 2020).

    Stacks `layers` blocks; each block applies a gated pair of
    dilated-inception temporal convolutions and, when gcn_true, two mix-hop
    graph convolutions (one per edge direction of a learned adjacency), with
    residual and skip connections. The summed skip outputs pass through two
    1x1 "end" convolutions to produce n_out output channels.

    NOTE(review): forward() takes input of shape (batch, n_in, input_size)
    and returns the end-conv output squeezed on its last axis — confirm the
    exact output shape against callers.
    """
    def __init__(self, n_in=8, in_dim=1, n_out=1, input_size=3, target_size=3, subgraph_size=3, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=16, skip_channels=32, end_channels=64, layers=3, gcn_depth=5, propalpha=0.05, tanhalpha=3, layer_norm_affline=True, gcn_true=True, buildA_true=True, device='cuda', predefined_A=None, static_feat=None, dropout=0.):
        super(MTGNN, self).__init__()
        self.gcn_true = gcn_true          # enable the graph-convolution branch
        self.buildA_true = buildA_true    # learn the adjacency (vs. predefined_A)
        self.input_size = input_size      # number of graph nodes / series
        self.dropout = dropout
        self.predefined_A = predefined_A
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.gconv1 = nn.ModuleList()
        self.gconv2 = nn.ModuleList()
        self.norm = nn.ModuleList()
        # 1x1 conv lifting the raw in_dim channels to residual_channels.
        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))
        # Graph-learning module; consulted only when buildA_true in forward().
        self.gc = graph_constructor(input_size, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat)

        self.n_in = n_in
        kernel_size = 7  # widest kernel in dilated_inception
        # Total receptive field of the stacked dilated temporal convolutions.
        if dilation_exponential>1:
            self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1))
        else:
            self.receptive_field = layers*(kernel_size-1) + 1

        for i in range(1):
            if dilation_exponential>1:
                rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1))
            else:
                rf_size_i = i*layers*(kernel_size-1)+1
            new_dilation = 1
            for j in range(1,layers+1):
                # rf_size_j: receptive field consumed through layer j; it sizes
                # the skip convolutions and per-layer LayerNorms below.
                if dilation_exponential > 1:
                    rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1))
                else:
                    rf_size_j = rf_size_i+j*(kernel_size-1)

                self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation))
                self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation))
                self.residual_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                    out_channels=residual_channels,
                                                 kernel_size=(1, 1)))
                # Skip convs collapse the remaining time axis to length 1.
                if self.n_in>self.receptive_field:
                    self.skip_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                    out_channels=skip_channels,
                                                    kernel_size=(1, self.n_in-rf_size_j+1)))
                else:
                    self.skip_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                    out_channels=skip_channels,
                                                    kernel_size=(1, self.receptive_field-rf_size_j+1)))

                if self.gcn_true:
                    # Two mix-hop convs: one per direction of the learned graph.
                    self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha))
                    self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha))

                if self.n_in>self.receptive_field:
                    self.norm.append(LayerNorm((residual_channels, input_size, self.n_in - rf_size_j + 1),elementwise_affine=layer_norm_affline))
                else:
                    self.norm.append(LayerNorm((residual_channels, input_size, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline))

                new_dilation *= dilation_exponential

        self.layers = layers
        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
                                             out_channels=end_channels,
                                             kernel_size=(1,1),
                                             bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
                                             out_channels=n_out,
                                             kernel_size=(1,1),
                                             bias=True)
        # skip0 taps the raw input; skipE taps the final residual state.
        if self.n_in > self.receptive_field:
            self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.n_in), bias=True)
            self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.n_in-self.receptive_field+1), bias=True)

        else:
            self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True)
            self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True)


        # Default node index set used when forward() is called with idx=None.
        self.idx = torch.arange(self.input_size).to(device)


    def forward(self, input, idx=None):
        """Forecast from `input` (batch, n_in, input_size); `idx` optionally
        restricts the graph to a subset of node indices."""
        # (batch, n_in, nodes) -> (batch, 1, nodes, n_in)
        input = torch.unsqueeze(input,dim=1)
        input = input.transpose(2,3)
        seq_len = input.size(3)
        assert seq_len==self.n_in, 'input sequence length not equal to preset sequence length'

        # Left-pad the time axis so it covers the full receptive field.
        if self.n_in<self.receptive_field:
            input = nn.functional.pad(input,(self.receptive_field-self.n_in,0,0,0))

        if self.gcn_true:
            if self.buildA_true:
                if idx is None:
                    adp = self.gc(self.idx)
                else:
                    adp = self.gc(idx)
            else:
                adp = self.predefined_A

        x = self.start_conv(input)
        skip = self.skip0(F.dropout(input, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            # Gated temporal convolution: tanh(filter) * sigmoid(gate).
            filter = self.filter_convs[i](x)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](x)
            gate = torch.sigmoid(gate)
            x = filter * gate
            x = F.dropout(x, self.dropout, training=self.training)
            s = x
            s = self.skip_convs[i](s)
            skip = s + skip
            if self.gcn_true:
                # Propagate along both edge directions of the learned graph.
                x = self.gconv1[i](x, adp)+self.gconv2[i](x, adp.transpose(1,0))
            else:
                x = self.residual_convs[i](x)

            # Residual connection, cropped to the (shrunken) time length of x.
            x = x + residual[:, :, :, -x.size(3):]
            if idx is None:
                x = self.norm[i](x,self.idx)
            else:
                x = self.norm[i](x,idx)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        x = x.squeeze(-1) # .permute(0,3,2,1)
        return x


###################
##### StemGNN #####
###################

class GLU(nn.Module):
    """Gated linear unit: a left linear projection gated elementwise by the
    sigmoid of a right linear projection."""

    def __init__(self, input_channel, output_channel):
        super(GLU, self).__init__()
        self.linear_left = nn.Linear(input_channel, output_channel)
        self.linear_right = nn.Linear(input_channel, output_channel)

    def forward(self, x):
        gate = torch.sigmoid(self.linear_right(x))
        return self.linear_left(x) * gate


class StockBlockLayer(nn.Module):
    """StemGNN block: spectral graph convolution combined with a spectral
    sequence (FFT + GLU) cell.

    Applies the multi-order Chebyshev laplacian `mul_L` to the input, runs
    the result through FFT -> stacked GLUs -> inverse FFT, and maps it back
    with a learned per-order weight tensor. Always produces a forecast
    branch; produces a backcast branch only for the first stack
    (stack_cnt == 0).
    """
    def __init__(self, n_in, unit, multi_layer, stack_cnt=0):
        super(StockBlockLayer, self).__init__()
        self.n_in = n_in
        self.unit = unit            # number of nodes/series
        self.stack_cnt = stack_cnt  # position of this block in the stack
        self.multi = multi_layer
        # One weight matrix per Chebyshev order (K+1 = 4 orders).
        self.weight = nn.Parameter(
            torch.Tensor(1, 3 + 1, 1, self.n_in * self.multi,
                         self.multi * self.n_in))  # [K+1, 1, in_c, out_c]
        nn.init.xavier_normal_(self.weight)
        self.forecast = nn.Linear(self.n_in * self.multi, self.n_in * self.multi)
        self.forecast_result = nn.Linear(self.n_in * self.multi, self.n_in)
        # The backcast head only exists for the first stack.
        if self.stack_cnt == 0:
            self.backcast = nn.Linear(self.n_in * self.multi, self.n_in)
        self.backcast_short_cut = nn.Linear(self.n_in, self.n_in)
        self.relu = nn.ReLU()
        self.GLUs = nn.ModuleList()
        self.output_channel = 4 * self.multi
        # Three GLU stages, each with separate real/imaginary branches
        # (even indices = real part, odd indices = imaginary part).
        for i in range(3):
            if i == 0:
                self.GLUs.append(GLU(self.n_in * 4, self.n_in * self.output_channel))
                self.GLUs.append(GLU(self.n_in * 4, self.n_in * self.output_channel))
            elif i == 1:
                self.GLUs.append(GLU(self.n_in * self.output_channel, self.n_in * self.output_channel))
                self.GLUs.append(GLU(self.n_in * self.output_channel, self.n_in * self.output_channel))
            else:
                self.GLUs.append(GLU(self.n_in * self.output_channel, self.n_in * self.output_channel))
                self.GLUs.append(GLU(self.n_in * self.output_channel, self.n_in * self.output_channel))

    def spe_seq_cell(self, input):
        """Spectral sequence cell: FFT along the last axis, GLUs applied to
        the real and imaginary parts separately, then inverse FFT (real part
        kept).

        input: (batch, K+1, input_channel, node_cnt, n_in).
        """
        batch_size, k, input_channel, node_cnt, n_in = input.size()
        input = input.view(batch_size, -1, node_cnt, n_in)
        # torch.rfft was removed from newer PyTorch; torch.fft.fft is the
        # replacement (complex tensor instead of a trailing re/im axis).
        # ffted = torch.rfft(input, 1, onesided=False)
        # real = ffted[..., 0].permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1)
        # img = ffted[..., 1].permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1)
        ffted = torch.fft.fft(input)
        real = ffted.real.permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1)
        img = ffted.imag.permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1)
        # Even-indexed GLUs transform the real part, odd-indexed the imaginary.
        for i in range(3):
            real = self.GLUs[i * 2](real)
            img = self.GLUs[2 * i + 1](img)
        real = real.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous()
        img = img.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous()
        # n_in_as_inner = torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1)
        n_in_as_inner = torch.complex(real, img)
        # iffted = torch.irfft(n_in_as_inner, 1, onesided=False)
        iffted = torch.fft.ifft(n_in_as_inner)
        return iffted.real # .float(), .to(torch.float32)

    def forward(self, x, mul_L):
        """x: node signal tensor; mul_L: (K+1, N, N) Chebyshev laplacians.
        Returns (forecast, backcast_or_None)."""
        mul_L = mul_L.unsqueeze(1)
        x = x.unsqueeze(1)
        # Graph Fourier transform: project x onto each laplacian order.
        gfted = torch.matmul(mul_L, x)
        gconv_input = self.spe_seq_cell(gfted).unsqueeze(2)
        igfted = torch.matmul(gconv_input, self.weight)
        igfted = torch.sum(igfted, dim=1)  # sum over the K+1 orders
        forecast_source = torch.sigmoid(self.forecast(igfted).squeeze(1))
        forecast = self.forecast_result(forecast_source)
        if self.stack_cnt == 0:
            # Backcast: residual signal handed to the next stack.
            backcast_short = self.backcast_short_cut(x).squeeze(1)
            backcast_source = torch.sigmoid(self.backcast(igfted) - backcast_short)
        else:
            backcast_source = None
        return forecast, backcast_source


class StemGNN(nn.Module):
    """StemGNN spectral temporal graph network (Cao et al., NeurIPS 2020).

    Learns a latent correlation graph from the series via a GRU plus additive
    self-attention, builds the graph's normalized laplacian and its Chebyshev
    basis, then runs `stack_cnt` StockBlockLayer blocks whose forecasts are
    summed and passed through a TSDecoder.

    NOTE(review): TSDecoder is defined elsewhere in this project — its output
    layout drives the final permutes in forward(); confirm there.
    """
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, stack_cnt=2, multi_layer=3, dropout_rate=0., leaky_rate=0.):
        super(StemGNN, self).__init__() # 5, 0.2   , device='cuda'
        self.unit = input_size  # number of nodes/series
        self.stack_cnt = stack_cnt
        self.alpha = leaky_rate
        self.n_in = n_in
        self.n_out = n_out
        # Attention projections used by the latent correlation layer.
        self.weight_key = nn.Parameter(torch.zeros(size=(self.unit, 1)))
        nn.init.xavier_uniform_(self.weight_key.data, gain=1.414)
        self.weight_query = nn.Parameter(torch.zeros(size=(self.unit, 1)))
        nn.init.xavier_uniform_(self.weight_query.data, gain=1.414)
        self.GRU = nn.GRU(self.n_in, self.unit)
        self.multi_layer = multi_layer
        self.stock_block = nn.ModuleList()
        self.stock_block.extend(
            [StockBlockLayer(self.n_in, self.unit, self.multi_layer, stack_cnt=i) for i in range(self.stack_cnt)])
        # self.fc = nn.Sequential(
            # nn.Linear(int(self.n_in), int(self.n_in)),
            # nn.LeakyReLU(),
            # nn.Linear(int(self.n_in), self.n_out),
        # )
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.dropout = nn.Dropout(p=dropout_rate)
        #self.to(device)
        ### refer TS2Vec
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

    def get_laplacian(self, graph, normalize):
        """
        Return the laplacian of the graph.
        :param graph: the graph structure without self loop, [N, N].
        :param normalize: whether to use the normalized laplacian.
        :return: graph laplacian.
        """
        if normalize:
            # L = I - D^{-1/2} A D^{-1/2}
            D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)
        else:
            # L = D - A
            D = torch.diag(torch.sum(graph, dim=-1))
            L = D - graph
        return L

    def cheb_polynomial(self, laplacian):
        """
        Compute the Chebyshev Polynomial, according to the graph laplacian.
        :param laplacian: the graph laplacian, [N, N].
        :return: the multi order Chebyshev laplacian, [K, N, N].
        """
        N = laplacian.size(0)  # [N, N]
        laplacian = laplacian.unsqueeze(0)
        # Order-0 term is zeros here (the reference implementation's choice).
        first_laplacian = torch.zeros([1, N, N], device=laplacian.device, dtype=torch.float)
        second_laplacian = laplacian
        third_laplacian = (2 * torch.matmul(laplacian, second_laplacian)) - first_laplacian
        forth_laplacian = 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian
        multi_order_laplacian = torch.cat([first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0)
        return multi_order_laplacian

    def latent_correlation_layer(self, x):
        """Infer the node correlation graph from the series via GRU +
        self-attention; return its Chebyshev basis and the attention matrix."""
        input, _ = self.GRU(x.permute(2, 0, 1).contiguous())
        input = input.permute(1, 0, 2).contiguous()
        attention = self.self_graph_attention(input)
        attention = torch.mean(attention, dim=0)
        degree = torch.sum(attention, dim=1)
        # laplacian is sym or not
        attention = 0.5 * (attention + attention.T)
        degree_l = torch.diag(degree)
        diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7))
        laplacian = torch.matmul(diagonal_degree_hat,
                                 torch.matmul(degree_l - attention, diagonal_degree_hat))
        mul_L = self.cheb_polynomial(laplacian)
        return mul_L, attention

    def self_graph_attention(self, input):
        """Additive self-attention over nodes; returns (batch, N, N) weights."""
        input = input.permute(0, 2, 1).contiguous()
        bat, N, fea = input.size()
        key = torch.matmul(input, self.weight_key)
        query = torch.matmul(input, self.weight_query)
        # Pairwise scores: key_i + query_j for every node pair (i, j).
        data = key.repeat(1, 1, N).view(bat, N * N, 1) + query.repeat(1, N, 1)
        data = data.squeeze(2)
        data = data.view(bat, N, -1)
        data = self.leakyrelu(data)
        attention = F.softmax(data, dim=2)
        attention = self.dropout(attention)
        return attention

    def graph_fft(self, input, eigenvectors):
        # Graph Fourier transform: project input onto the eigenvector basis.
        return torch.matmul(eigenvectors, input)

    def forward(self, x):
        """Forecast from series tensor x; see class note on output layout."""
        mul_L, attention = self.latent_correlation_layer(x)
        X = x.unsqueeze(1).permute(0, 1, 3, 2).contiguous()
        result = []
        # The first block emits forecast + backcast; the backcast feeds block 2.
        for stack_i in range(self.stack_cnt):
            forecast, X = self.stock_block[stack_i](X, mul_L)
            result.append(forecast)
        forecast = result[0] + result[1]
        # forecast = self.fc(forecast)
        forecast = self.tsdecoder(forecast).permute(0, 2, 1)
        if forecast.size()[-1] == 1:
            out = forecast.unsqueeze(1).squeeze(-1)#, attention
        else:
            out = forecast.permute(0, 2, 1).contiguous()#, attention
        return out


#######################
##### Transformer #####
#######################

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017), added
    to the input along the sequence dimension and followed by dropout."""

    def __init__(self, d_model, dropout=0., max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(positions * freqs)
        pe[:, 1::2] = torch.cos(positions * freqs)
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch dim.
        self.register_buffer('pe', pe.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        """x: (seq_len, batch, d_model); returns dropout(x + PE[:seq_len])."""
        return self.dropout(x + self.pe[:x.size(0), :])

class Transformer(nn.Module):
    """Sequence-to-one Transformer forecaster.

    Projects the input series into d_model, adds sinusoidal positional
    encodings, runs a standard nn.Transformer using the last source step as
    the decoder target, and maps the result to (batch, n_out, target_size).

    Adapted from darts' _PositionalEncoding / _TransformerModule
    (darts/models/forecasting/transformer_model.py).
    """
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, d_model: int = 128, nhead: int = 4, num_encoder_layers: int = 3, num_decoder_layers: int = 3, dim_feedforward: int = 256, dropout: float = 0., activation: str = 'relu', custom_encoder = None, custom_decoder = None):
        super(Transformer, self).__init__()

        self.input_size = input_size
        self.target_size = target_size
        self.n_out = n_out

        self.encoder = nn.Linear(input_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout, n_in)

        # Core encoder/decoder stack.
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout, activation=activation,
                                          custom_encoder=custom_encoder,
                                          custom_decoder=custom_decoder)

        self.decoder = nn.Linear(d_model, n_out * self.target_size) # fusing2

    def _create_transformer_inputs(self, data):
        # (batch, n_in, input_size) -> (n_in, batch, input_size); the decoder
        # receives only the final time step of the source sequence.
        src = data.permute(1, 0, 2)
        return src, src[-1:, :, :]

    def forward(self, data):
        src, tgt = self._create_transformer_inputs(data)

        # sqrt(input_size) scaling, cf. section 3.2.1 of Vaswani et al. (2017).
        scale = math.sqrt(self.input_size)
        src = self.positional_encoding(self.encoder(src) * scale)
        tgt = self.positional_encoding(self.encoder(tgt) * scale)

        out = self.decoder(self.transformer(src=src, tgt=tgt))

        # (1, batch, n_out * target_size) -> (batch, n_out, target_size)
        return out[0, :, :].view(-1, self.n_out, self.target_size)

###################################################
## n-hits-main\src\models\transformer\transformer.py
## DataEmbedding -> PositionalEmbedding


## n-hits-main\src\models\components\transformer.py
# Cell
## prefix, Transformer
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer whose feed-forward sublayer is a
    pair of kernel-1 Conv1d's applied across the feature dimension."""

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # Self-attention sublayer with residual connection.
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attended))

        # Conv feed-forward sublayer; transpose to (B, D, L) for Conv1d.
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn


class TransformerEncoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with conv distilling layers.

    When `conv_layers` is given, each attention layer except the last is
    followed by one conv layer; the final attention layer then runs without
    a mask. Returns the encoded sequence and per-layer attention maps.
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(TransformerEncoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x: [B, L, D]
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            # Distilling path: pair each conv layer with the preceding attention layer.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


class TransformerDecoderLayer(nn.Module):
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x = self.norm1(x)

        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])

        y = x = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm3(x + y)


class TransformerDecoder(nn.Module):
    """Stack of decoder layers with optional final norm and output projection."""
    def __init__(self, layers, norm_layer=None, projection=None):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # Run every decoder layer against the shared encoder memory `cross`.
        for decoder_layer in self.layers:
            x = decoder_layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)

        if self.norm is not None:
            x = self.norm(x)
        return x if self.projection is None else self.projection(x)

## n-hits-main\src\models\components\embed.py
# Cell
class PositionalEmbedding(nn.Module):
    """Sinusoidal positional encoding table (Vaswani et al., 2017, sec. 3.5).

    Precomputes a (1, max_len, d_model) table of interleaved sin/cos values.
    `forward` returns the slice matching the input's time dimension — it does
    NOT add the input; callers add the result to their embeddings themselves.
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        # Fix: the original set the misspelled attribute `pe.require_grad`,
        # which did nothing; the buffer registration below already keeps the
        # table out of gradients/optimizers, and we build it under no_grad.
        with torch.no_grad():
            pe = torch.zeros(max_len, d_model).float()

            position = torch.arange(0, max_len).float().unsqueeze(1)
            # div_term has ceil(d_model / 2) entries (one per sin column).
            div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

            pe[:, 0::2] = torch.sin(position * div_term)
            # Generalization: slice div_term so odd d_model works too — the
            # cos columns number floor(d_model / 2), one fewer than div_term
            # when d_model is odd (the original raised a shape error there).
            pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

            pe = pe.unsqueeze(0)  # pe: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, timesteps, d_model); returns (1, timesteps, d_model).
        return self.pe[:, :x.size(1)]

## n-hits-main\src\models\components\selfattention.py
# from math import sqrt

# Cell
class TriangularCausalMask():
    """Boolean causal mask of shape (B, 1, L, L): True strictly above the
    diagonal, i.e. the future positions a query must not attend to."""
    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            ones = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = torch.triu(ones, diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask

# Cell
class FullAttention(nn.Module):
    """Standard scaled dot-product attention over all key positions.

    Expects queries/keys/values shaped (B, L, H, E); returns the attended
    values shaped (B, L, H, D), plus the attention map when
    `output_attention` is set (otherwise None).
    """
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        scale = self.scale or 1. / math.sqrt(E)

        # Raw dot-product scores per head: (B, H, L_q, L_k).
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            # Default to a causal mask when none is supplied.
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        return (context, weights) if self.output_attention else (context, None)


class AttentionLayer(nn.Module):
    """Multi-head wrapper: projects q/k/v into per-head subspaces, runs the
    wrapped attention, and projects the concatenated heads back to d_model."""
    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        kv_len = keys.shape[1]
        heads = self.n_heads

        # Split the projected features into heads: (B, L, H, d_head).
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        out, attn = self.inner_attention(q, k, v, attn_mask)
        # Merge heads back and project to d_model.
        out = out.view(batch, q_len, -1)
        return self.out_projection(out), attn


## n-hits-main\src\models\transformer\transformer.py
## DataEmbedding -> PositionalEmbedding
# Cell
class _Transformer(nn.Module):
    '''
    Vanilla Transformer with O(L^2) complexity

    Encoder-decoder transformer for time-series forecasting: the input
    window is linearly projected to d_model, summed with a sinusoidal
    positional embedding, encoded, and the last input time step is decoded
    into `n_out` future steps of `target_size` features.
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, d_model: int = 128, factor: int = 3, n_heads: int = 4, d_ff: int = 512, dropout: float = 0., activation: str = 'relu', e_layers: int = 3, d_layers: int = 3, output_attention: bool = False):
        # pred_len, output_attention, dropout, enc_in, dec_in, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, d_model: int = 128, nhead: int = 4, num_encoder_layers: int = 3, num_decoder_layers: int = 3, dim_feedforward: int = 512, dropout: float = 0., activation: str = 'relu', output_attention: bool = False #embed, freq, 
        super(_Transformer, self).__init__()
        self.pred_len = n_out #pred_len   # number of forecast steps
        self.c_out = target_size #c_out   # features per forecast step
        self.output_attention = output_attention

        ## Embedding
        # self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        # self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
        # self.enc_embedding = PositionalEmbedding(d_model, enc_in)
        # self.dec_embedding = PositionalEmbedding(d_model, dec_in)
        ## assume enc_in = dec_in = n_in
        self.pos_embedding = PositionalEmbedding(d_model, n_in)
        # Shared input projection for both encoder and decoder streams.
        self.fc = nn.Linear(input_size, d_model)

        # Encoder
        self.encoder = TransformerEncoder(
            [
                TransformerEncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # Decoder
        # Self-attention is causal (mask_flag=True); cross-attention is not.
        self.decoder = TransformerDecoder(
            [
                TransformerDecoderLayer(
                    AttentionLayer(
                        FullAttention(True, factor, attention_dropout=dropout, output_attention=False),
                        d_model, n_heads),
                    AttentionLayer(
                        FullAttention(False, factor, attention_dropout=dropout, output_attention=False),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model),
            # Each decoded step emits the entire c_out * pred_len forecast at once.
            projection=nn.Linear(d_model, self.c_out * self.pred_len, bias=True) # fusing2
        )

    def _create_transformer_inputs(self, data, flip01=False):
        """Return (src, tgt): tgt is the last time step of src.

        With flip01=True the input is treated as time-first and permuted to
        batch-first slicing on dim 0; otherwise time is dim 1.
        """
        if flip01:
            src = data.permute(1, 0, 2)
            tgt = src[-1:, :, :]
        else:
            src = data
            tgt = src[:, -1:, :]
        return src, tgt

    def forward(self, data,#x_mark_enc, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Forecast from `data`.

        data: assumed (batch, n_in, input_size) — inferred from the dim-1
        slicing in `_create_transformer_inputs`; TODO confirm with callers.
        Returns (batch, pred_len, c_out), plus attention maps when
        `output_attention` is set.
        """
        x_enc, x_dec = self._create_transformer_inputs(data)

        # Project to d_model and add (broadcast) the positional table.
        x_enc = self.fc(x_enc)
        enc_emb = self.pos_embedding(x_enc)#, x_mark_enc)
        enc_out = enc_emb + x_enc
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        x_dec = self.fc(x_dec)
        dec_emb = self.pos_embedding(x_dec)#, x_mark_dec)
        dec_out = dec_emb + x_dec
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)

        # dec_out is (batch, 1, c_out * pred_len); reshape into per-step outputs.
        if self.output_attention:
            return dec_out.view(-1, self.pred_len, self.c_out), attns#[:, -self.pred_len:, :], attns
        else:
            return dec_out.view(-1, self.pred_len, self.c_out)#[:, -self.pred_len:, :]  # [B, L, D]



#####################
##### ConvTrans #####
#####################

class CausalConv1d(nn.Conv1d):
    """1-D convolution that only looks at past time steps.

    Left-pads the input by (kernel_size - 1) * dilation so that output[t]
    depends only on input[<= t] and the sequence length is preserved.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        # padding=0 here: the causal (left-only) padding is applied manually
        # in forward, since nn.Conv1d only supports symmetric padding.
        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias)

        self._left_pad = (kernel_size - 1) * dilation

    def forward(self, input):
        padded = F.pad(input, (self._left_pad, 0))
        return super(CausalConv1d, self).forward(padded)


class context_embedding(nn.Module):
    """Embeds each time step with a causal convolution followed by tanh.

    Input (B, in_channels, T) -> output (B, embedding_size, T).
    """
    def __init__(self, in_channels=1, embedding_size=256, k=5):
        super(context_embedding, self).__init__()
        self.causal_convolution = CausalConv1d(in_channels, embedding_size, kernel_size=k)

    def forward(self, x):
        return torch.tanh(self.causal_convolution(x))

## see Transformer above
# class PositionalEncoding(nn.Module):
    # def __init__(self, d_model, max_len=5000):
        # super(PositionalEncoding, self).__init__()
        # pe = torch.zeros(max_len, d_model)
        # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # pe[:, 0::2] = torch.sin(position * div_term)
        # pe[:, 1::2] = torch.cos(position * div_term)
        # pe = pe.unsqueeze(0).transpose(0, 1)
        # self.register_buffer('pe', pe)

    # def forward(self, x):
        # return x + self.pe[:x.size(0), :]


class ConvTrans(nn.Module): # TransformerTimeSeries
    """
    Time Series application of transformers based on paper
    
    causal_convolution_layer parameters:
        in_channels: the number of features per time point
        out_channels: the number of features outputted per time point
        kernel_size: k is the width of the 1-D sliding kernel
        
    nn.Transformer parameters:
        d_model: the size of the embedding vector (input)
    
    PositionalEncoding parameters:
        d_model: the size of the embedding vector (positional vector)
        dropout: the dropout to be used on the sum of positional+embedding vector

    NOTE(review): `PositionalEncoding` and `TSDecoder` are defined elsewhere
    in this file (outside this chunk). Input to forward is assumed to be
    batch-first (batch, seq_len, input_size) — inferred from the permutes
    below; confirm against callers.
    """
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, feature_size=128, kernel_size=7, nhead=4, num_layers=1, dropout=0):
        super(ConvTrans,self).__init__() # 256, 8
        self.model_type = 'Transformer'
        # Cached causal mask, rebuilt lazily when the sequence length changes.
        self.src_mask = None
        # Causal-conv value embedding: (B, input_size, T) -> (B, feature_size, T).
        self.input_embedding = context_embedding(in_channels=input_size, embedding_size=feature_size, k=kernel_size)
        self.pos_encoder = PositionalEncoding(feature_size)
        # self.positional_embedding = nn.Embedding(feature_size*2,feature_size)

        self.encode_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=nhead, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encode_layer, num_layers=num_layers)

        self.fc1 = nn.Linear(feature_size, input_size) # decoder
        # self.fc2 = nn.Linear(n_in, n_out)
        ### refer TS2Vec
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

        self.init_weights()

    def init_weights(self):
        # Uniform init for the output head only; other submodules keep
        # their PyTorch defaults.
        initrange = 0.1
        self.fc1.bias.data.zero_()
        self.fc1.weight.data.uniform_(-initrange, initrange)

    def forward(self, src): #x,y,attention_masks
        
        # concatenate observed points and time covariate
        # (Batchsize*feature_size*n_time_points)
        # (translated) the original notebook's input was a 2 x seq_len matrix with two
        # features: the current value and the current time point (a covariate)
        # z = torch.cat((y.unsqueeze(1),x.unsqueeze(1)),1)

        # input_embedding returns shape (Batch size,embedding size,sequence len) -> need (sequence len,Batch size,embedding_size)
        # z_embedding = self.input_embedding(z).permute(2,0,1)
        
        # get my positional embeddings (Batch size, sequence_len, embedding_size) -> need (sequence len,Batch size,embedding_size)
        # positional_embeddings = self.positional_embedding(x.type(torch.long)).permute(1,0,2)
        
        # input_embedding = z_embedding+positional_embeddings

        # The shape of the 2D attn_mask, (timesteps, timesteps)
        # timesteps = src.size(1)
        # Rebuild the cached causal mask only when the sequence length changes.
        if self.src_mask is None or self.src_mask.size(1) != src.size(1):
            device = src.device
            mask = self._generate_square_subsequent_mask(src.size(1)).to(device)
            self.src_mask = mask

        # bs = 64, seq_len = 10, emb_sz = 256
        # print(src.size())# torch.Size([64, 10, 1])
        # self.pos_encoder(src)#.permute(0,2,1)
        # src_pe = torch.stack([self.pos_encoder(src[:,:,[i]]) for i in range(src.size(2))], dim=-1) # (translated) no extra dimension needed
        src_pe = self.pos_encoder(src[:,:,[0]]) # (translated) one channel suffices — the positional encoding depends only on sequence length
        # print(src_pe.size())# torch.Size([64, 10, 256])
        src_ce = self.input_embedding(src.permute(0,2,1))
        # print(src_ce.size())# torch.Size([64, 256, 10])
        # Both terms permuted to the (seq_len, batch, emb) layout that the
        # (non-batch-first) nn.TransformerEncoder expects.
        src_embedding = src_pe.permute(1,0,2) + src_ce.permute(2,0,1)
        # src_embedding torch.Size([10, 64, 256])
        transformer_embedding = self.transformer_encoder(src_embedding, self.src_mask).permute(1,0,2)
        # transformer_embedding torch.Size([64, 10, 256])
        # transformer_embedding = self.transformer_encoder(input_embedding,attention_masks)

        output = self.fc1(transformer_embedding) #
        # print(output.size()) # torch.Size([bs, seq_len, input_size])
        # return self.fc2(output.permute(0,2,1)).permute(0,2,1)
        # output.squeeze()
        return self.tsdecoder(output)

    def _generate_square_subsequent_mask(self, sz):
        # Additive float mask: 0.0 on allowed (past/self) positions,
        # -inf on future positions.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask


####################
##### Informer #####
####################

## n-hits-main\src\models\transformer\informer.py

## see _Transformer above,             Decoder, DecoderLayer, Encoder, EncoderLayer
# from ..components.transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer

# Cell
class ConvLayer(nn.Module):
    """Informer distilling layer: circular conv + batch-norm + ELU + stride-2
    max-pool, roughly halving the time dimension of a (B, L, D) input.

    NOTE(review): padding=2 with kernel_size=3 matches the pre-torch-1.5
    code path of the original Informer repo (current versions use padding=1),
    so output lengths here come out one step longer — confirm intended.
    """
    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=2,
                                  padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # (B, L, D) -> (B, D, L) so the conv runs over the time axis.
        y = self.downConv(x.permute(0, 2, 1))
        y = self.activation(self.norm(y))
        y = self.maxPool(y)
        # Back to (B, L', D).
        return y.transpose(1, 2)

## see _Transformer above,               AttentionLayer
# from ..components.selfattention import AttentionLayer, ProbAttention

class ProbMask():
    """Causal mask for the top-u queries selected by ProbAttention.

    Builds the full (L, S) strictly-upper-triangular mask, expands it to
    (B, H, L, S), then gathers the rows of the selected query indices so the
    result matches `scores`' shape.
    """
    def __init__(self, B, H, L, index, scores, device="cpu"):
        S = scores.shape[-1]
        upper = torch.ones(L, S, dtype=torch.bool).to(device).triu(1)
        expanded = upper[None, None, :].expand(B, H, L, S)
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        selected = expanded[batch_idx, head_idx, index, :].to(device)
        self._mask = selected.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask

class ProbAttention(nn.Module):
    """ProbSparse attention from Informer: O(L log L) approximation of full
    attention.

    Only the top-u queries (u ~ factor * ln(L_q)), ranked by a sparsity
    measure estimated from a random sample of keys, attend over all keys;
    every other query keeps a cheap context prior (the mean of V, or the
    cumulative sum of V under the causal mask).
    """
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor  # the c in U = c * ln(L): controls sample/top-k sizes
        self.scale = scale  # optional explicit score scale; defaults to 1/sqrt(D) in forward
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        # NOTE(review): this dropout is never applied in forward/_update_context
        # of this implementation — confirm whether that is intentional.
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        """Score each query on `sample_k` random keys, keep the `n_top`
        sparsest queries, and return their full score rows plus indices."""
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))  # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        # NOTE(review): .squeeze() removes EVERY size-1 dim, so B == 1 or
        # H == 1 would also be squeezed away here — confirm callers always
        # use B, H > 1, or this indexing breaks downstream.
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

        # find the Top_k query with sparisty measurement
        # M = max score - mean score: large when a query's score distribution
        # is peaked (i.e. far from uniform), so that query is "active".
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Context prior for the queries that were NOT selected: mean of V
        (unmasked) or the causal running sum of V (masked)."""
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else:  # use mask
            assert (L_Q == L_V)  # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the rows of the selected queries with their real
        attention output; other rows keep the prior from _get_initial_context."""
        B, H, L_V, D = V.shape

        if self.mask_flag:
            # The provided attn_mask is ignored: a ProbMask matched to the
            # selected query rows is built instead.
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

        # Scatter the attended values back into the selected query positions.
        context_in[torch.arange(B)[:, None, None],
        torch.arange(H)[None, :, None],
        index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            # Non-selected queries are reported with a uniform 1/L_V map.
            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask):
        # queries/keys/values: (B, L, H, D) — transposed to (B, H, L, D) below.
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)

        # Clamp both sample sizes to the actual sequence lengths.
        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor
        scale = self.scale or 1. / math.sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        return context.contiguous(), attn


# Cell
class Informer(nn.Module):
    '''
    Informer with Propspare attention in O(LlogL) complexity

    Encoder-decoder built from the shared TransformerEncoder/Decoder stacks
    but with ProbAttention in place of FullAttention. Unlike `_Transformer`,
    the decoder consumes the FULL input sequence (not just the last step)
    and the final forecast is produced by a TSDecoder head (defined elsewhere
    in this file).
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, d_model: int = 128, factor: int = 3, n_heads: int = 4, d_ff: int = 512, dropout: float = 0., activation: str = 'relu', e_layers: int = 3, d_layers: int = 3, output_attention: bool = False, distil: bool = False):
        #(self, pred_len, output_attention, enc_in, dec_in, d_model, c_out, embed, freq, dropout, factor, n_heads, d_ff, activation, e_layers, d_layers, distil):
        super(Informer, self).__init__()
        self.pred_len = n_out #pred_len
        self.c_out = target_size #c_out
        self.output_attention = output_attention

        # Embedding
        # self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        # self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
        ## assume enc_in = dec_in = n_in
        self.pos_embedding = PositionalEmbedding(d_model, n_in)
        # Shared input projection for both encoder and decoder streams.
        self.fc1 = nn.Linear(input_size, d_model)
        # self.fc2 = nn.Linear(n_in, n_out)
        ### refer TS2Vec
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

        # Encoder
        # With distil=True, a ConvLayer is inserted after each encoder layer
        # except the last, shrinking the sequence between layers.
        self.encoder = TransformerEncoder(
            [
                TransformerEncoderLayer(
                    AttentionLayer(
                        ProbAttention(False, factor, attention_dropout=dropout,
                                      output_attention=output_attention), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            [
                ConvLayer(
                    d_model
                ) for l in range(e_layers - 1)
            ] if distil else None,
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # Decoder
        # Self-attention is causal ProbAttention; cross-attention is unmasked.
        self.decoder = TransformerDecoder(
            [
                TransformerDecoderLayer(
                    AttentionLayer(
                        ProbAttention(True, factor, attention_dropout=dropout, 
                                    output_attention=False), d_model, n_heads),
                    AttentionLayer(
                        ProbAttention(False, factor, attention_dropout=dropout, 
                                    output_attention=False), d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model),
            # Per-step projection to target_size; the TSDecoder head fuses
            # the sequence into the final forecast.
            projection=nn.Linear(d_model, self.c_out, bias=True) # * self.pred_len
        )

    def _create_transformer_inputs(self, data, flip01=False):
        """Return (src, tgt); here both are the full sequence (generative
        inference), unlike `_Transformer` which slices the last step."""
        if flip01:
            src = data.permute(1, 0, 2)
            tgt = src #[-1:, :, :]
        else: # Generative Inference 
            src = data
            tgt = src #[:, -1:, :]
        return src, tgt

    def forward(self, data,#x_mark_enc, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Forecast from `data` (assumed (batch, n_in, input_size) — TODO
        confirm with callers). Returns the TSDecoder output, plus the
        encoder attention maps when `output_attention` is set."""
        x_enc, x_dec = self._create_transformer_inputs(data)

        # Project to d_model and add (broadcast) the positional table.
        x_enc = self.fc1(x_enc)
        enc_emb = self.pos_embedding(x_enc)#, x_mark_enc)
        src = enc_emb + x_enc
        # print('src: ', src.size())
        x_dec = self.fc1(x_dec)
        dec_emb = self.pos_embedding(x_dec)#, x_mark_dec)
        tgt = dec_emb + x_dec
        # print('tgt: ', tgt.size())

        memory, attns = self.encoder(src, attn_mask=enc_self_mask)
        # print('memory: ', memory.size())

        dec_out = self.decoder(tgt, memory, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        # (translated) dec_out[:, -self.pred_len:, :] — slicing only the last
        # pred_len steps was considered too fragile / hard to train, so the
        # TSDecoder consumes the full decoded sequence instead.
        # output = self.fc2(dec_out.permute(0,2,1)).permute(0,2,1)
        output = self.tsdecoder(dec_out)

        if self.output_attention:
            return output, attns
        else:
            return output # size == (-1, self.pred_len, self.c_out)


######################
##### Autoformer #####
######################

## n-hits-main\src\models\transformer\autoformer.py

# from ..components.autoformer import Encoder, Decoder, EncoderLayer, DecoderLayer,  my_Layernorm, series_decomp

# Cell
class my_Layernorm(nn.Module):
    '''
    LayerNorm variant for the seasonal component: applies a standard
    LayerNorm, then removes the per-series mean over the time dimension so
    the output is zero-centered along dim 1.
    '''
    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # Subtract the time-mean, broadcast back over the sequence length.
        seq_mean = torch.mean(normed, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return normed - seq_mean


class moving_avg(nn.Module):
    '''
    Moving average block to highlight the trend of time series.

    Smooths the time dimension (dim=1) of a (batch, time, channels) tensor
    with a sliding mean of width `kernel_size`, replicating the first/last
    time step at the ends so that, with stride == 1, the output keeps the
    input's time length.
    '''
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series.
        # Generalization: splitting (kernel_size - 1) into (k-1)//2 front and
        # k//2 end replicas keeps the output length equal to the input length
        # for EVEN kernels too — the original symmetric (k-1)//2 padding on
        # both sides dropped one time step when kernel_size was even. For odd
        # kernels (e.g. the usual default 25) this is identical to before.
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, self.kernel_size // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d expects (batch, channels, time): swap in, swap back out.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    '''
    Series decomposition block: splits a series into (seasonal, trend),
    where the trend is a stride-1 moving average over time and the seasonal
    part is the residual left after removing it.
    '''
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend


class AutoformerEncoderLayer(nn.Module):
    '''
    Autoformer encoder layer with the progressive decomposition architecture:
    self-attention then a conv feed-forward block, each followed by a series
    decomposition that keeps only the seasonal (residual) part.
    '''
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(AutoformerEncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # Attention sublayer; decomposition discards the trend component.
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attn_out))

        # Position-wise feed-forward via kernel-size-1 convolutions.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        seasonal, _ = self.decomp2(x + ff_out)
        return seasonal, attn


class AutoformerEncoder(nn.Module):
    '''
    Autoformer encoder: a stack of decomposition encoder layers, optionally
    interleaved with conv layers (each attention layer but the last is
    followed by one conv layer). Returns the encoded sequence and the list
    of per-layer attention maps.
    '''
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(AutoformerEncoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            # Distilling path: conv layer after each attention layer but the last.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


class AutoformerDecoderLayer(nn.Module):
    '''
    Autoformer decoder layer with the progressive decomposition architecture:
    masked self-attention, cross-attention and a conv feed-forward block,
    each followed by a series decomposition. The trend parts extracted at
    the three stages are summed, projected to c_out channels, and returned
    alongside the seasonal output.
    '''
    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(AutoformerDecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # Masked self-attention; first trend component split off.
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))

        # Cross-attention over the encoder memory; second trend component.
        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))

        # Conv feed-forward; third trend component.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff_out)

        # Accumulate the trends and project them to the output channel count.
        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend


class AutoformerDecoder(nn.Module):
    '''
    Autoformer decoder: a stack of decoder layers, each of which refines the
    seasonal component and contributes an increment to the running trend.
    '''
    def __init__(self, layers, norm_layer=None, projection=None):
        super(AutoformerDecoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        # every layer returns (refined seasonal part, trend increment)
        for decoder_layer in self.layers:
            x, layer_trend = decoder_layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + layer_trend

        if self.norm is not None:
            x = self.norm(x)
        if self.projection is not None:
            x = self.projection(x)
        return x, trend

# from ..components.autocorrelation import AutoCorrelation, AutoCorrelationLayer

# Cell
class AutoCorrelation(nn.Module):
    '''
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Improvement: the inference and full aggregation paths previously called
    torch.topk twice on the same tensor (once for values, once for indices);
    a single call now returns both.
    '''
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor  # scales top_k = int(factor * log(length))
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        '''
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        values/corr: (batch, head, channel, length).  The top delays are
        chosen globally (averaged over batch, head and channel) so torch.roll
        can be used instead of a per-sample gather.
        '''
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # normalise the selected correlations into mixing weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: roll the series by each selected delay and mix
        tmp_values = values
        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        '''
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Delays are chosen per sample, so the shifted series is read with a
        gather on a doubled (wrap-around) copy of `values`.
        '''
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length, device=values.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k — one topk call yields both values and indices
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # normalise the selected correlations into mixing weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)  # doubled so a shifted gather never runs off the end
        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        '''
        Standard version of Autocorrelation: delays chosen independently for
        every (batch, head, channel) position.
        '''
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length, device=values.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k — one topk call yields both values and indices
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # normalise the selected correlations into mixing weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values, dtype=torch.float, device=values.device)
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # align the key/value length with the query length
        # (the zero-padding path assumes E == D — TODO confirm for asymmetric heads)
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :], dtype=torch.float, device=queries.device)
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: autocorrelation via FFT (Wiener–Khinchin)
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


class AutoCorrelationLayer(nn.Module):
    """Wraps a correlation module with per-head query/key/value projections."""
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        # default head width: split the model dimension evenly
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        k_len = keys.shape[1]
        heads = self.n_heads

        # project and split the features into `heads` sub-spaces
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, k_len, heads, -1)
        v = self.value_projection(values).view(batch, k_len, heads, -1)

        out, attn = self.inner_correlation(q, k, v, attn_mask)
        # merge heads back and project to the model dimension
        out = out.view(batch, q_len, -1)
        return self.out_projection(out), attn

# Cell
class Autoformer(nn.Module):
    '''
    Autoformer is the first method to achieve the series-wise connection,
    with inherent O(LlogL) complexity

    Forward pipeline: decompose the input into seasonal/trend initialisations,
    encode the raw series, progressively decode seasonal + trend, and map the
    decoder output to the forecast with a TS2Vec-style TSDecoder head.
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, d_model: int = 128, factor: int = 1, n_heads: int = 4, d_ff: int = 512, dropout: float = 0., activation: str = 'relu', e_layers: int = 2, d_layers: int = 1, output_attention: bool = False, moving_avg: int = 7, label_len: int = 10):
        # (self, seq_len, label_len, pred_len, output_attention, enc_in, dec_in, d_model, c_out, embed, freq, dropout, factor, n_heads, d_ff, moving_avg, activation, e_layers, d_layers):
        super(Autoformer, self).__init__()
        # self.seq_len = seq_len
        # the decoder is seeded with the last `label_len` steps of the input
        # plus `n_out` placeholder steps, so the input must cover both
        assert n_in >= (label_len + n_out)
        self.label_len = label_len
        self.pred_len = n_out #pred_len
        self.c_out = target_size #c_out
        self.output_attention = output_attention

        # Decomp: moving-average based seasonal/trend split
        kernel_size = moving_avg
        self.decomp = series_decomp(kernel_size)

        # Embedding
        # The series-wise connection inherently contains the sequential information.
        # Thus, we can discard the position embedding of transformers.
        # self.enc_embedding = DataEmbedding_wo_pos(enc_in, d_model, embed, freq, dropout)
        # self.dec_embedding = DataEmbedding_wo_pos(dec_in, d_model, embed, freq, dropout)
        ## assume enc_in = dec_in = n_in
        self.pos_embedding = PositionalEmbedding(d_model, n_in)
        # shared linear lift from input channels to d_model (used for enc and dec)
        self.fc1 = nn.Linear(input_size, d_model)
        # self.fc2 = nn.Linear(n_in, n_out)
        ### refer TS2Vec #,  n_in  <-  Algorithm 1 Overall Autoformer Procedure 18:
        self.tsdecoder = TSDecoder(n_in=(min(n_in, label_len) + n_out), n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

        # Encoder: e_layers of autocorrelation blocks (no masking on the encoder side)
        self.encoder = AutoformerEncoder(
            [
                AutoformerEncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor, attention_dropout=dropout,
                                        output_attention=output_attention), d_model, n_heads),
                    d_model,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            norm_layer=my_Layernorm(d_model)
        )
        # Decoder: masked self-correlation plus unmasked cross-correlation per layer
        self.decoder = AutoformerDecoder(
            [
                AutoformerDecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(True, factor, attention_dropout=dropout,
                                        output_attention=False), d_model, n_heads),
                    AutoCorrelationLayer(
                        AutoCorrelation(False, factor, attention_dropout=dropout,
                                        output_attention=False), d_model, n_heads),
                    d_model,
                    self.c_out,
                    d_ff,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=nn.Linear(d_model, self.c_out, bias=True)# * self.pred_len
        )

    def _create_transformer_inputs(self, data, flip01=False):
        # Build (src, tgt) pair; here tgt simply aliases src (the decoder
        # input is rebuilt from the decomposition in `forward`).
        if flip01:
            src = data.permute(1, 0, 2)
            tgt = src#[-1:, :, :]
        else:
            src = data
            tgt = src#[:, -1:, :]
        return src, tgt

    def forward(self, data,#x_mark_enc, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # data: (batch, n_in, input_size)
        x_enc, x_dec = self._create_transformer_inputs(data)

        # decomp init: trend placeholder = per-series mean, seasonal placeholder = zeros
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
        seasonal_init, trend_init = self.decomp(x_enc)
        # decoder input: last label_len observed steps + pred_len placeholders
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)

        ## enc: lift to d_model, add positional embedding, encode
        enc_0 = self.fc1(x_enc)
        enc_emb = self.pos_embedding(enc_0)#, x_mark_enc)
        src = enc_emb + enc_0
        # print('src: ', src.size())
        memory, attns = self.encoder(src, attn_mask=enc_self_mask)
        # print('memory: ', memory.size())

        ## dec: same lift/embedding applied to the seasonal initialisation
        dec_0 = self.fc1(seasonal_init)
        dec_emb = self.pos_embedding(dec_0)#, x_mark_dec)
        tgt = dec_emb + dec_0
        # print('tgt: ', tgt.size())
        seasonal_part, trend_part = self.decoder(tgt, memory, x_mask=dec_self_mask, cross_mask=dec_enc_mask, trend=trend_init)

        # final: recombine trend and seasonal, then map to the horizon
        dec_out = trend_part + seasonal_part
        # dec_out[:, -self.pred_len:, :] — slicing like this is too fragile and hard to train(?)
        # output = dec_out[:, -self.pred_len:, :]
        # self.fc2(dec_out.permute(0,2,1)).permute(0,2,1)
        output = self.tsdecoder(dec_out)

        if self.output_attention:
            return output, attns
        else:
            return output


###############
##### TCN #####
###############

class ResidualBlock(nn.Module):
    """One dilated causal convolution block of the TCN with a residual skip path."""

    def __init__(self, num_filters: int, kernel_size: int, dilation_base: int, dropout: float, weight_norm: bool, nr_blocks_below: int, num_layers: int, input_size: int, target_size: int):
        super(ResidualBlock, self).__init__()

        self.dilation_base = dilation_base
        self.kernel_size = kernel_size
        self.activation = nn.ReLU()
        self.dropout_fn = nn.Dropout(p=dropout)
        self.num_layers = num_layers
        self.nr_blocks_below = nr_blocks_below

        # only the first block sees the raw input width; only the last emits target width
        in_ch = num_filters if nr_blocks_below > 0 else input_size
        out_ch = num_filters if nr_blocks_below != num_layers - 1 else target_size
        dilation = dilation_base ** nr_blocks_below
        self.conv1 = nn.Conv1d(in_ch, num_filters, kernel_size, dilation=dilation)
        self.conv2 = nn.Conv1d(num_filters, out_ch, kernel_size, dilation=dilation)
        if weight_norm:
            self.conv1 = nn.utils.weight_norm(self.conv1)
            self.conv2 = nn.utils.weight_norm(self.conv2)
        # 1x1 conv aligns channel counts on the skip path when they differ
        self.conv3 = None if in_ch == out_ch else nn.Conv1d(in_ch, out_ch, 1)
        self.init_weights()

    def init_weights(self):
        """Initialise all convolution weights from N(0, 0.01)."""
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.conv3 is not None:
            self.conv3.weight.data.normal_(0, 0.01)

    def forward(self, x):
        residual = x
        # left-only padding keeps the convolutions causal
        pad = (self.dilation_base ** self.nr_blocks_below) * (self.kernel_size - 1)

        out = F.pad(x, (pad, 0))
        out = self.dropout_fn(self.activation(self.conv1(out)))

        out = F.pad(out, (pad, 0))
        out = self.conv2(out)
        # the final block skips the ReLU so the output range stays unconstrained
        if self.nr_blocks_below < self.num_layers - 1:
            out = self.activation(out)
        out = self.dropout_fn(out)

        # add the (possibly channel-projected) skip connection
        if self.conv1.in_channels != self.conv2.out_channels:
            residual = self.conv3(residual)
        out += residual

        return out

class _TCN(nn.Module):
    '''
    Temporal Convolutional Network.

    References:
    1. _ResidualBlock / _TCNModule in darts (darts/models/forecasting/tcn_model.py)
    2. TemporalBlock / TCN in tsai (tsai/models/TCN.py)
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, kernel_size: int = 7, num_filters: int = 32, num_layers = 5, dilation_base: int = 2, weight_norm: bool = True, dropout: float = 0.):
        super(_TCN, self).__init__()

        # Defining parameters
        self.input_size = input_size
        self.n_in = n_in
        self.n_filters = num_filters
        self.kernel_size = kernel_size
        self.n_out = n_out
        self.target_size = target_size
        self.dilation_base = dilation_base
        self.dropout = dropout

        # If num_layers is not passed, compute the number of layers needed
        # for the receptive field to cover the full input history
        if num_layers is None and dilation_base > 1:
            num_layers = math.ceil(math.log((n_in - 1) * (dilation_base - 1) / (kernel_size - 1) / 2 + 1, dilation_base))
            print("[TCN] Number of layers chosen: " + str(num_layers) + '\r', end='')
        elif num_layers is None:
            num_layers = math.ceil((n_in - 1) / (kernel_size - 1) / 2)
            print("[TCN] Number of layers chosen: " + str(num_layers) + '\r', end='')
        self.num_layers = num_layers

        # Building the stack of residual blocks
        self.res_blocks_list = [
            ResidualBlock(num_filters, kernel_size, dilation_base, self.dropout, weight_norm,
                          block_idx, num_layers, self.input_size, target_size)
            for block_idx in range(num_layers)
        ]
        self.res_blocks = nn.ModuleList(self.res_blocks_list)
        # TS2Vec-style decoder head maps the encoded sequence to the forecast
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)

    def forward(self, x):
        # x: (batch_size, n_in, input_size)
        batch_size = x.size(0)
        x = x.transpose(1, 2)

        for block in self.res_blocks:
            x = block(x)

        x = x.transpose(1, 2)
        x = x.view(batch_size, self.n_in, self.target_size)

        return self.tsdecoder(x)


###################################################
#### a better implementation of TCN
import fastai.layers as fastai_layers
# import fastai.torch_core as tc
#     fastai.torch_core.Module, 
# Same as nn.Module, but no need for subclasses to call super().__init__
#     fastai.layers.Flatten(), 
# Flatten x to a single dimension, e.g. at end of a model.

# Cell
class Chomp1d(nn.Module):
    """Trim the trailing `chomp_size` timesteps added by causal padding.

    Fix: the original slice ``x[:, :, :-self.chomp_size]`` returns an EMPTY
    tensor when ``chomp_size == 0`` (``:-0`` is the same as ``:0``); that
    case is now guarded and returns the input unchanged.
    """
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # nothing to remove — avoid the `:-0` slice that would drop everything
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Two weight-normalised causal convolutions with a residual connection."""
    def __init__(self, ni, nf, ks, stride, dilation, padding, dropout=0.):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.utils.weight_norm(nn.Conv1d(ni, nf, ks, stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = nn.utils.weight_norm(nn.Conv1d(nf, nf, ks, stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        # conv -> chomp (drop the non-causal tail) -> relu -> dropout, twice
        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1x1 conv matches channel counts on the skip path when needed
        self.downsample = None if ni == nf else nn.Conv1d(ni, nf, 1)
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Initialise convolution weights from N(0, 0.01)."""
        for conv in (self.conv1, self.conv2, self.downsample):
            if conv is not None:
                conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        transformed = self.net(x)
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(transformed + skip)

def TemporalConvNet(c_in, layers, ks=2, dropout=0.):
    """Stack TemporalBlocks with exponentially growing dilation (2**i)."""
    blocks = []
    for i, nf in enumerate(layers):
        dilation = 2 ** i
        ni = layers[i - 1] if i > 0 else c_in
        blocks.append(TemporalBlock(ni, nf, ks, stride=1, dilation=dilation,
                                    padding=(ks - 1) * dilation, dropout=dropout))
    return nn.Sequential(*blocks)

# Cell
class GAP1d(nn.Module):
    "Global Adaptive Pooling + Flatten"
    def __init__(self, output_size=1):
        super(GAP1d, self).__init__()
        self.gap = nn.AdaptiveAvgPool1d(output_size)
        # nn.Flatten() (start_dim=1) flattens everything after the batch axis,
        # the same behaviour as fastai.layers.Flatten(), so the block no
        # longer needs the fastai dependency
        self.flatten = nn.Flatten()
    def forward(self, x):
        return self.flatten(self.gap(x))

class TCN(nn.Module):
    '''
    c_in, input_size (n_timeseries)
    c_out, target_size (n_timeseries)
    seq_len, n_in (timesteps), similar to RNN, adaptive for any length input seq
    '''
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, layers=5*[32], ks=7, conv_dropout=0., fc_dropout=0.):
        super(TCN, self).__init__()
        self.n_out = n_out
        self.target_size = target_size
        self.tcn = TemporalConvNet(input_size, layers, ks=ks, dropout=conv_dropout)
        # pool the temporal axis down to n_out positions, then flatten
        self.gap = GAP1d(output_size=n_out)
        self.dropout = nn.Dropout(fc_dropout) if fc_dropout else None
        # one linear head producing all target channels for every horizon step
        self.linear = nn.Linear(layers[-1], target_size * n_out)
        self.init_weights()

    def init_weights(self):
        """Initialise the linear head from N(0, 0.01)."""
        self.linear.weight.data.normal_(0, 0.01)

    def forward(self, x):
        # x: (batch_size, n_in, input_size)
        batch_size = x.size(0)
        features = self.tcn(x.transpose(2, 1))
        pooled = self.gap(features)
        if self.dropout is not None:
            pooled = self.dropout(pooled)
        return self.linear(pooled).view(batch_size, self.n_out, self.target_size)


##################
##### TS2Vec #####
##################

## Encoder only
## \ts2vec-main\models\encoder.py
## \ts2vec-main\models\dilated_conv.py

class SamePadConv(nn.Module):
    """Conv1d that preserves the input length ('same' padding) for any dilation."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        self.receptive_field = (kernel_size - 1) * dilation + 1
        # pad symmetrically by half the receptive field
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=self.receptive_field // 2,
            dilation=dilation,
            groups=groups
        )
        # an even receptive field over-pads by one step; drop it on the right
        self.remove = 1 if self.receptive_field % 2 == 0 else 0

    def forward(self, x):
        out = self.conv(x)
        return out[:, :, :-self.remove] if self.remove > 0 else out
    
class ConvBlock(nn.Module):
    """Pre-activation residual block of two SamePadConv layers."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
        super().__init__()
        self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
        self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
        # project the skip path when channel counts differ (or on the final block)
        if in_channels != out_channels or final:
            self.projector = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.projector = None

    def forward(self, x):
        residual = x if self.projector is None else self.projector(x)
        out = self.conv2(F.gelu(self.conv1(F.gelu(x))))
        return out + residual

class DilatedConvEncoder(nn.Module):
    """Stack of ConvBlocks with dilation doubling at each level."""
    def __init__(self, in_channels, channels, kernel_size):
        super().__init__()
        blocks = []
        for i, out_ch in enumerate(channels):
            in_ch = in_channels if i == 0 else channels[i - 1]
            blocks.append(ConvBlock(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                dilation=2 ** i,
                final=(i == len(channels) - 1)
            ))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

## from .dilated_conv import DilatedConvEncoder

def generate_continuous_mask(B, T, n=5, l=0.1):
    """Boolean (B, T) mask with `n` random runs of length `l` set False per row."""
    mask = torch.full((B, T), True, dtype=torch.bool)
    # fractional n / l are interpreted relative to the sequence length T
    if isinstance(n, float):
        n = int(n * T)
    n = max(min(n, T // 2), 1)

    if isinstance(l, float):
        l = int(l * T)
    l = max(l, 1)

    for row in range(B):
        for _ in range(n):
            start = np.random.randint(T - l + 1)
            mask[row, start:start + l] = False
    return mask

def generate_binomial_mask(B, T, p=0.5):
    """Boolean (B, T) mask where each entry is independently True with probability p."""
    draws = np.random.binomial(1, p, size=(B, T))
    return torch.from_numpy(draws).to(torch.bool)

class TSEncoder(nn.Module):
    """TS2Vec encoder: linear input projection + dilated-conv feature extractor.

    Fix: the original wrote ``x[~nan_mask] = 0`` directly into the *caller's*
    tensor, silently zeroing NaN timesteps in the input the caller still holds.
    NaN masking is now done out-of-place so the input is left untouched.
    """
    def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.mask_mode = mask_mode
        self.input_fc = nn.Linear(input_dims, hidden_dims)
        self.feature_extractor = DilatedConvEncoder(
            hidden_dims,
            [hidden_dims] * depth + [output_dims],
            kernel_size=3
        )
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):  # x: B x T x input_dims
        # timesteps containing any NaN are treated as missing
        nan_mask = ~x.isnan().any(dim=-1)
        # out-of-place masking: do not mutate the caller's tensor
        x = x.masked_fill(~nan_mask.unsqueeze(-1), 0)
        x = self.input_fc(x)  # B x T x Ch

        # choose a masking strategy: random while training, keep all at eval
        if mask is None:
            mask = self.mask_mode if self.training else 'all_true'

        if mask == 'binomial':
            mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'continuous':
            mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'all_true':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
        elif mask == 'all_false':
            mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
        elif mask == 'mask_last':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
            mask[:, -1] = False

        # never unmask a timestep that was missing in the input
        mask &= nan_mask
        x = x.masked_fill(~mask.unsqueeze(-1), 0)

        # conv encoder operates channels-first
        x = x.transpose(1, 2)  # B x Ch x T
        x = self.repr_dropout(self.feature_extractor(x))  # B x Co x T
        x = x.transpose(1, 2)  # B x T x Co

        return x


class TSDecoder(nn.Module):
    '''
    An empirical decoder mapping an encoded sequence to the forecast horizon.

    Modes:
      - 'slicing': take the last n_out steps / first target_size channels (no parameters)
      - 'fusing1': one Linear over the time axis (assumes input_size == target_size)
      - 'fusing2': one Linear over the flattened (time x channel) features

    Fix: `forward` already supported 'slicing' but `__init__` raised
    NotImplementedError on it (dead branch); the mode is now accepted at
    construction time as well.

    Note: 'fusing2' flattens with .reshape(...) because the incoming tensor
    may be non-contiguous (.view(...) would raise a RuntimeError there).
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, dec_mode: str = 'fusing2'):
        # n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries
        super().__init__()
        self.dec_mode = dec_mode
        self.n_out = n_out
        self.target_size = target_size
        if dec_mode == 'fusing1':  # assumes input_size == target_size by default
            self.fc = nn.Linear(n_in, n_out)
        elif dec_mode == 'fusing2':
            self.fc = nn.Linear(n_in * input_size, n_out * target_size)
        elif dec_mode == 'slicing':
            self.fc = None  # parameter-free slicing head
        else:
            raise NotImplementedError

    def forward(self, x):  # x: batch_size x timesteps x input_size
        batch_size = x.size(0)

        if self.dec_mode == 'slicing':
            output = x[:, -self.n_out:, :self.target_size]
        elif self.dec_mode == 'fusing1':
            output = self.fc(x.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :self.target_size]
        elif self.dec_mode == 'fusing2':
            output = self.fc(x.reshape(batch_size, -1)).view(-1, self.n_out, self.target_size)

        return output


class TS2Vec(nn.Module):
    '''
    my TS2Vec for simple TSF: a TS2Vec representation encoder followed by an
    empirical decoder head.
    '''
    def __init__(self, n_in: int = 8, n_out: int = 1, input_size: int = 3, target_size: int = 3, hidden_dims: int = 64, depth: int = 10, mask_mode: str = 'binomial', dec_mode: str = 'fusing2'):
        super().__init__()

        self.encoder = TSEncoder(input_dims=input_size, output_dims=target_size,
                                 hidden_dims=hidden_dims, depth=depth, mask_mode=mask_mode)
        self.decoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size,
                                 target_size=target_size, dec_mode=dec_mode)

    def forward(self, x):  # x: batch_size x timesteps x input_size
        representation = self.encoder(x)
        return self.decoder(representation)


##################
##### SCINet #####
##################

class Splitting(nn.Module):
    """Split a sequence along the time axis into even- and odd-indexed steps."""
    def __init__(self):
        super(Splitting, self).__init__()

    def even(self, x):
        # timesteps 0, 2, 4, ...
        return x[:, 0::2, :]

    def odd(self, x):
        # timesteps 1, 3, 5, ...
        return x[:, 1::2, :]

    def forward(self, x):
        '''Returns the odd and even part'''
        return self.even(x), self.odd(x)


class Interactor(nn.Module):
    def __init__(self, in_planes, splitting=True,
                 kernel = 5, dropout=0.5, groups = 1, hidden_size = 1, INN = True):
        super(Interactor, self).__init__()
        self.modified = INN
        self.kernel_size = kernel
        self.dilation = 1
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.groups = groups
        if self.kernel_size % 2 == 0:
            pad_l = self.dilation * (self.kernel_size - 2) // 2 + 1 #by default: stride==1 
            pad_r = self.dilation * (self.kernel_size) // 2 + 1 #by default: stride==1 

        else:
            pad_l = self.dilation * (self.kernel_size - 1) // 2 + 1 # we fix the kernel size of the second layer as 3.
            pad_r = self.dilation * (self.kernel_size - 1) // 2 + 1
        self.splitting = splitting
        self.split = Splitting()

        modules_P = []
        modules_U = []
        modules_psi = []
        modules_phi = []
        prev_size = 1

        size_hidden = self.hidden_size
        modules_P += [
            nn.ReplicationPad1d((pad_l, pad_r)),

            nn.Conv1d(in_planes * prev_size, int(in_planes * size_hidden),
                      kernel_size=self.kernel_size, dilation=self.dilation, stride=1, groups= self.groups),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),

            nn.Dropout(self.dropout),
            nn.Conv1d(int(in_planes * size_hidden), in_planes,
                      kernel_size=3, stride=1, groups= self.groups),
            nn.Tanh()
        ]
        modules_U += [
            nn.ReplicationPad1d((pad_l, pad_r)),
            nn.Conv1d(in_planes * prev_size, int(in_planes * size_hidden),
                      kernel_size=self.kernel_size, dilation=self.dilation, stride=1, groups= self.groups),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(self.dropout),
            nn.Conv1d(int(in_planes * size_hidden), in_planes,
                      kernel_size=3, stride=1, groups= self.groups),
            nn.Tanh()
        ]

        modules_phi += [
            nn.ReplicationPad1d((pad_l, pad_r)),
            nn.Conv1d(in_planes * prev_size, int(in_planes * size_hidden),
                      kernel_size=self.kernel_size, dilation=self.dilation, stride=1, groups= self.groups),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(self.dropout),
            nn.Conv1d(int(in_planes * size_hidden), in_planes,
                      kernel_size=3, stride=1, groups= self.groups),
            nn.Tanh()
        ]
        modules_psi += [
            nn.ReplicationPad1d((pad_l, pad_r)),
            nn.Conv1d(in_planes * prev_size, int(in_planes * size_hidden),
                      kernel_size=self.kernel_size, dilation=self.dilation, stride=1, groups= self.groups),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(self.dropout),
            nn.Conv1d(int(in_planes * size_hidden), in_planes,
                      kernel_size=3, stride=1, groups= self.groups),
            nn.Tanh()
        ]
        self.phi = nn.Sequential(*modules_phi)
        self.psi = nn.Sequential(*modules_psi)
        self.P = nn.Sequential(*modules_P)
        self.U = nn.Sequential(*modules_U)

    def forward(self, x):
        if self.splitting:
            (x_even, x_odd) = self.split(x)
        else:
            (x_even, x_odd) = x

        if self.modified:
            x_even = x_even.permute(0, 2, 1)
            x_odd = x_odd.permute(0, 2, 1)
#             print(x_odd.size(), x_even.size())
            d = x_odd.mul(torch.exp(self.phi(x_even)))
            c = x_even.mul(torch.exp(self.psi(x_odd)))

            x_even_update = c + self.U(d)
            x_odd_update = d - self.P(c)

            return (x_even_update, x_odd_update)

        else:
            x_even = x_even.permute(0, 2, 1)
            x_odd = x_odd.permute(0, 2, 1)

            d = x_odd - self.P(x_even)
            c = x_even + self.U(d)

            return (c, d)


class InteractorLevel(nn.Module):
    """Wraps a single splitting Interactor so one tree level has a uniform call site."""

    def __init__(self, in_planes, kernel, dropout, groups, hidden_size, INN):
        super(InteractorLevel, self).__init__()
        self.level = Interactor(
            in_planes=in_planes,
            splitting=True,
            kernel=kernel,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN,
        )

    def forward(self, x):
        # delegate: yields the (even_update, odd_update) pair from the Interactor
        return self.level(x)

class LevelSCINet(nn.Module):
    """One SCINet tree level: interact the even/odd halves, then restore (B, T, D) layout."""

    def __init__(self, in_planes, kernel_size, dropout, groups, hidden_size, INN):
        super(LevelSCINet, self).__init__()
        self.interact = InteractorLevel(
            in_planes=in_planes,
            kernel=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN,
        )

    def forward(self, x):
        even, odd = self.interact(x)
        # interaction ran channels-first; hand callers (B, T, D) back
        return even.permute(0, 2, 1), odd.permute(0, 2, 1)

class SCINet_Tree(nn.Module):
    """Recursive binary tree of LevelSCINet blocks.

    Each node interacts the even/odd sub-series; non-leaf nodes recurse into
    both halves and re-interleave the results back into time order.
    """

    def __init__(self, in_planes, current_level, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.current_level = current_level

        self.workingblock = LevelSCINet(
            in_planes=in_planes,
            kernel_size=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN)

        # leaves (level 0) have no children
        if current_level != 0:
            self.SCINet_Tree_odd = SCINet_Tree(in_planes, current_level - 1, kernel_size, dropout, groups, hidden_size, INN)
            self.SCINet_Tree_even = SCINet_Tree(in_planes, current_level - 1, kernel_size, dropout, groups, hidden_size, INN)

    def zip_up_the_pants(self, even, odd):
        """Interleave the even and odd branches back into one sequence.

        Both inputs are (B, L, D); the result is (B, L_even+L_odd, D) with
        even[0], odd[0], even[1], odd[1], ... and, when the even branch is one
        step longer, its last step appended at the end.
        """
        even = even.permute(1, 0, 2)  # (L, B, D)
        odd = odd.permute(1, 0, 2)
        mlen = min(even.shape[0], odd.shape[0])
        # vectorized interleave of the common prefix:
        # (mlen, 2, B, D) reshaped so index 2*i is even[i] and 2*i+1 is odd[i]
        merged = torch.stack((even[:mlen], odd[:mlen]), dim=1).reshape(2 * mlen, *even.shape[1:])
        if odd.shape[0] < even.shape[0]:
            merged = torch.cat((merged, even[-1:]), dim=0)
        return merged.permute(1, 0, 2)  # back to (B, L, D)

    def forward(self, x):
        x_even_update, x_odd_update = self.workingblock(x)
        # Recursively reorder the sub-series (see ./utils/recursive_demo.py for an emulation).
        if self.current_level == 0:
            return self.zip_up_the_pants(x_even_update, x_odd_update)
        return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update),
                                     self.SCINet_Tree_odd(x_odd_update))

class EncoderTree(nn.Module):
    """Thin wrapper that builds and applies the recursive SCINet_Tree encoder."""

    def __init__(self, in_planes, num_levels, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.levels = num_levels
        self.SCINet_Tree = SCINet_Tree(
            in_planes=in_planes,
            current_level=num_levels - 1,
            kernel_size=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN)

    def forward(self, x):
        # the tree handles all splitting/merging internally
        return self.SCINet_Tree(x)

class SCINet(nn.Module):
    '''
    SCINet (Sample Convolution and Interaction Network) forecaster.

    Runs the input through one or two EncoderTree stacks with residual
    connections, then projects the time dimension from n_in to the output
    length. Optional transformer-style positional encoding and RIN
    (reversible instance normalization).

    Input: (batch, n_in, input_size). Output length depends on the projection
    configuration (n_out, or 1 when single_step_output_One is set).
    '''
    def __init__(self, n_in = 8, n_out = 1, input_size = 3, target_size = 3, hid_size = 1, num_stacks = 2,
                num_levels = 3, num_decoder_layer = 2, concat_len = 0, groups = 1, kernel = 5, dropout = 0.,
                single_step_output_One = 0, n_in_seg = 0, positionalE = False, modified = True, RIN=False):
        super(SCINet, self).__init__()

        self.input_size = input_size # input_dim
        self.n_in = n_in # input_len
        self.n_out = n_out # output_len
        self.hidden_size = hid_size # 'hidden channel of module'
        # H, EXPANSION RATE, 0.0625, type=float, help='hidden channel scale of module'
        self.num_levels = num_levels # 2
        self.groups = groups
        self.modified = modified
        self.kernel_size = kernel
        self.dropout = dropout
        self.single_step_output_One = single_step_output_One
        self.concat_len = concat_len
        self.pe = positionalE
        self.RIN = RIN
        self.num_decoder_layer = num_decoder_layer

        self.blocks1 = EncoderTree(
            in_planes=self.input_size,
            num_levels=self.num_levels,
            kernel_size=self.kernel_size,
            dropout=self.dropout,
            groups=self.groups,
            hidden_size=self.hidden_size,
            INN=modified)

        if num_stacks == 2: # we only implement two stacks at most.
            self.blocks2 = EncoderTree(
                in_planes=self.input_size,
                num_levels=self.num_levels,
                kernel_size=self.kernel_size,
                dropout=self.dropout,
                groups=self.groups,
                hidden_size=self.hidden_size,
                INN=modified)

        self.stacks = num_stacks

        # He-style init for convs; zeroed biases for linear layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

        # length produced by the first stack: feeds the second stack (n_in)
        # or is already the final output length (n_out)
        self.n_mid = n_in if num_stacks == 2 else n_out
        self.projection1 = nn.Conv1d(self.n_in, self.n_mid, kernel_size=1, stride=1, bias=False)
        self.div_projection = nn.ModuleList()
        self.overlap_len = self.n_in // 4
        self.div_len = self.n_in // 6

        if self.num_decoder_layer > 1:
            # NOTE(review): TSDecoder is defined elsewhere in this file/project.
            self.projection1 = TSDecoder(n_in=n_in, n_out=self.n_mid, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'
            for layer_idx in range(self.num_decoder_layer - 1):
                div_projection = nn.ModuleList()
                for i in range(6):
                    # each of the 6 time segments sees div_len steps plus an overlap window
                    lens = min(i * self.div_len + self.overlap_len, self.n_in) - i * self.div_len
                    div_projection.append(nn.Linear(lens, self.div_len))
                self.div_projection.append(div_projection)

        if self.single_step_output_One: # only output the N_th timestep.
            if self.stacks == 2:
                if self.concat_len:
                    self.projection2 = nn.Conv1d(self.concat_len + self.n_mid, 1,
                                                 kernel_size=1, bias=False)
                else:
                    self.projection2 = nn.Conv1d(self.n_in + self.n_mid, 1,
                                                 kernel_size=1, bias=False)
        else: # output the N timesteps.
            if self.stacks == 2:
                if self.concat_len:
                    self.projection2 = nn.Conv1d(self.concat_len + self.n_mid, self.n_out,
                                                 kernel_size=1, bias=False)
                else:
                    self.projection2 = nn.Conv1d(self.n_in + self.n_mid, self.n_out,
                                                 kernel_size=1, bias=False)

        # For positional encoding: transformer-style sin/cos timescales,
        # padded up to an even channel count.
        self.pe_hidden_size = input_size
        if self.pe_hidden_size % 2 == 1:
            self.pe_hidden_size += 1

        num_timescales = self.pe_hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0

        log_timescale_increment = (
                math.log(float(max_timescale) / float(min_timescale)) /
                max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

        ### RIN Parameters ###
        if self.RIN:
            self.affine_weight = nn.Parameter(torch.ones(1, 1, input_size))
            self.affine_bias = nn.Parameter(torch.zeros(1, 1, input_size))

    def get_position_encoding(self, x):
        """Return a (1, T, pe_hidden_size) sin/cos positional-encoding tensor for
        an input x of shape (B, T, C)."""
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32, device=x.device)
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)  # (T, num_timescales)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  # [T, C]
        # pad is a no-op here since pe_hidden_size is forced even in __init__
        signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
        signal = signal.view(1, max_length, self.pe_hidden_size)

        return signal

    def forward(self, x):
        # input length must be evenly divisible across all tree levels
        assert self.n_in % (np.power(2, self.num_levels)) == 0 # evenly divided the input length into two parts. (e.g., 32 -> 16 -> 8 -> 4 for 3 levels)
        if self.pe:
            pe = self.get_position_encoding(x)
            # out-of-place add so the caller's input tensor is not mutated
            if pe.shape[2] > x.shape[2]:
                # pe_hidden_size was padded to an even count; drop the extra channel
                x = x + pe[:, :, :-1]
            else:
                x = x + pe

        ### activated when RIN flag is set ###
        if self.RIN:
            print('/// RIN ACTIVATED ///\r',end='')
            means = x.mean(1, keepdim=True).detach()
            #mean
            x = x - means
            #var
            stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x /= stdev
            # affine
            x = x * self.affine_weight + self.affine_bias

        # the first stack
        res1 = x
        x = self.blocks1(x)
        x += res1
        if self.num_decoder_layer == 1:
            x = self.projection1(x)
        else:
            x = x.permute(0, 2, 1)
            for div_projection in self.div_projection:
                # allocate on the input's device (was implicitly CPU, which broke GPU runs)
                output = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
                for i, div_layer in enumerate(div_projection):
                    div_x = x[:, :, i*self.div_len:min(i*self.div_len+self.overlap_len, self.n_in)]
                    output[:, :, i*self.div_len:(i+1)*self.div_len] = div_layer(div_x)
                x = output
            x = self.projection1(x)

        if self.stacks == 1:
            ### reverse RIN ###
            if self.RIN:
                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x

        elif self.stacks == 2:
            MidOutPut = x
            if self.concat_len:
                x = torch.cat((res1[:, -self.concat_len:, :], x), dim=1)
            else:
                x = torch.cat((res1, x), dim=1)

            # the second stack
            res2 = x
            x = self.blocks2(x)
            x += res2
            x = self.projection2(x)

            ### Reverse RIN ###
            if self.RIN:
                MidOutPut = MidOutPut - self.affine_bias
                MidOutPut = MidOutPut / (self.affine_weight + 1e-10)
                MidOutPut = MidOutPut * stdev
                MidOutPut = MidOutPut + means

            if self.RIN:
                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x#, MidOutPut


###################
##### FNN/MLP #####
###################

class FNN(nn.Module):
    '''
    Feedforward Neural Networks
    ---------------------------
    Multi-layer, fully-connected feed-forward network.
    refer:
    1. _BlockRNNModule **
    in darts-master\darts\models\forecasting\block_rnn_model.py

    Flattens the (batch, n_in, input_size) input, runs it through a stack of
    Linear + activation layers, and reshapes to (batch, n_out, target_size).
    '''
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, activation='Softplus', fc_layers=[200,300,200,100]):
        super(FNN, self).__init__()

        self.in_len = n_in
        self.out_len = n_out
        self.input_size = input_size
        self.target_size = target_size

        hidden = list(fc_layers) if fc_layers is not None else []
        # layer widths: flattened input -> hidden widths -> flattened output
        widths = [input_size * n_in] + hidden + [n_out * target_size]
        act_cls = getattr(nn, activation)

        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(act_cls())  # note: activation also follows the final layer
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        # x: (batch_size, n_in, input_size)
        batch = x.size(0)
        flat = x.reshape(batch, self.in_len * self.input_size)
        out = self.fc(flat)
        # (batch_size, n_out, target_size)
        return out.view(batch, self.out_len, self.target_size)



###################
##### N-BEATS #####
###################
from typing import NewType, Union, List, Optional, Tuple
from enum import Enum

class _GType(Enum):
    # Waveform-generator ("basis") family used by the N-BEATS blocks/stacks:
    # a learned linear generator, a polynomial trend basis, or a sin/cos
    # seasonality basis.
    GENERIC = 1
    TREND = 2
    SEASONALITY = 3


GTypes = NewType('GTypes', _GType)  # alias used for g_type annotations


class TrendGenerator(nn.Module):
    """Polynomial trend basis for N-BEATS.

    Projects expansion coefficients onto powers of a normalized time grid
    t = [0, 1/n_out, ..., (n_out-1)/n_out]; the basis is a frozen parameter.
    """

    def __init__(self,
                 expansion_coefficient_dim,
                 n_out):
        super(TrendGenerator, self).__init__()

        # basis rows are t**0, t**1, ... -> shape (expansion_coefficient_dim, n_out)
        t = torch.arange(n_out) / n_out
        basis = torch.stack([t ** power for power in range(expansion_coefficient_dim)], dim=0)

        self.basis = nn.Parameter(basis, requires_grad=False)

    def forward(self, x):
        # (B, coeff_dim) @ (coeff_dim, n_out) -> (B, n_out)
        return torch.matmul(x, self.basis)


class SeasonalityGenerator(nn.Module):
    """Fourier seasonality basis for N-BEATS: a constant row plus sin/cos
    harmonics over a normalized time grid; the basis is a frozen parameter."""

    def __init__(self,
                 n_out):
        super(SeasonalityGenerator, self).__init__()
        half_minus_one = int(n_out / 2 - 1)
        # Normalize the time grid to [0, 1). Without the division the grid is
        # integer-valued, so cos(2*pi*k*i) == 1 and sin(2*pi*k*i) ~= 0 for every
        # entry and the seasonality basis degenerates to constant rows.
        t = torch.arange(n_out) / n_out
        cos_vectors = [torch.cos(t * 2 * np.pi * i) for i in range(1, half_minus_one + 1)]
        sin_vectors = [torch.sin(t * 2 * np.pi * i) for i in range(1, half_minus_one + 1)]

        # basis is of size (2 * int(n_out / 2 - 1) + 1, n_out)
        basis = torch.stack([torch.ones(n_out)] + cos_vectors + sin_vectors, dim=1).T

        self.basis = nn.Parameter(basis, requires_grad=False)

    def forward(self, x):
        # (B, num_harmonics) @ (num_harmonics, n_out) -> (B, n_out)
        return torch.matmul(x, self.basis)


class NbeatsBlock(nn.Module):
    """Single N-BEATS block: an FC stack feeds two linear heads that emit
    backcast/forecast expansion coefficients, which are projected through the
    block's waveform generators (generic / trend / seasonality basis)."""

    def __init__(self,
                 num_layers: int,
                 layer_width: int,
                 expansion_coefficient_dim: int,
                 n_in: int,
                 n_out: int,
                 g_type: GTypes):
        super(NbeatsBlock, self).__init__()

        self.num_layers = num_layers
        self.layer_width = layer_width
        self.n_out = n_out
        self.g_type = g_type
        self.relu = nn.ReLU()

        # fully connected stack before the fork; the ModuleList registers the
        # parameters while forward iterates the plain list
        self.linear_layer_stack_list = [nn.Linear(n_in, layer_width)]
        self.linear_layer_stack_list += [nn.Linear(layer_width, layer_width) for _ in range(num_layers - 1)]
        self.fc_stack = nn.ModuleList(self.linear_layer_stack_list)

        # Fully connected layers producing forecast/backcast expansion coefficients
        # (waveform generator parameters). The seasonality basis fixes its own
        # coefficient count (1 constant + sin/cos harmonics).
        if g_type == _GType.SEASONALITY:
            self.backcast_linear_layer = nn.Linear(layer_width, 2 * int(n_in / 2 - 1) + 1)
            self.forecast_linear_layer = nn.Linear(layer_width, (2 * int(n_out / 2 - 1) + 1))
        else:
            self.backcast_linear_layer = nn.Linear(layer_width, expansion_coefficient_dim)
            self.forecast_linear_layer = nn.Linear(layer_width, expansion_coefficient_dim)

        # waveform generator functions
        if g_type == _GType.GENERIC:
            self.backcast_g = nn.Linear(expansion_coefficient_dim, n_in)
            self.forecast_g = nn.Linear(expansion_coefficient_dim, n_out)
        elif g_type == _GType.TREND:
            self.backcast_g = TrendGenerator(expansion_coefficient_dim, n_in)
            self.forecast_g = TrendGenerator(expansion_coefficient_dim, n_out)
        elif g_type == _GType.SEASONALITY:
            self.backcast_g = SeasonalityGenerator(n_in)
            self.forecast_g = SeasonalityGenerator(n_out)
        else:
            # was `raise_log(ValueError(...), logger)`: neither raise_log nor
            # logger is defined in this file, so that path died with NameError
            raise ValueError("g_type not supported")

    def forward(self, x):
        batch_size = x.shape[0]

        # fully connected layer stack
        for layer in self.linear_layer_stack_list:
            x = self.relu(layer(x))

        # forked linear layers producing waveform generator parameters
        theta_backcast = self.backcast_linear_layer(x)
        theta_forecast = self.forecast_linear_layer(x)

        # set the expansion coefs in last dimension for the forecasts
        theta_forecast = theta_forecast.view(batch_size, -1)

        # waveform generator applications (project the expansion coefs onto basis vectors)
        x_hat = self.backcast_g(theta_backcast)
        y_hat = self.forecast_g(theta_forecast)

        # set the forecast horizon as the last dimension
        y_hat = y_hat.reshape(x.shape[0], self.n_out)

        return x_hat, y_hat


class NbeatsStack(nn.Module):
    """A stack of N-BEATS blocks sharing one basis family.

    Blocks are chained on the residual: each block's backcast is subtracted
    from its input, and the block forecasts are summed into the stack forecast.
    Interpretable bases (trend/seasonality) reuse one block instance so the
    weights are shared.
    """

    def __init__(self,
                 num_blocks: int,
                 num_layers: int,
                 layer_width: int,
                 expansion_coefficient_dim: int,
                 n_in: int,
                 n_out: int,
                 g_type: GTypes,
                 ):
        super(NbeatsStack, self).__init__()

        self.n_in = n_in
        self.n_out = n_out

        if g_type == _GType.GENERIC:
            self.blocks_list = [
                NbeatsBlock(num_layers, layer_width, expansion_coefficient_dim,
                            n_in, n_out, g_type)
                for _ in range(num_blocks)
            ]
        else:
            # same block instance is used for weight sharing
            shared_block = NbeatsBlock(num_layers, layer_width, expansion_coefficient_dim,
                                       n_in, n_out, g_type)
            self.blocks_list = [shared_block] * num_blocks

        self.blocks = nn.ModuleList(self.blocks_list)

    def forward(self, x):
        # accumulate one forecast vector of length n_out
        forecast = torch.zeros(x.shape[0],
                               self.n_out,
                               device=x.device,
                               dtype=x.dtype)

        residual = x
        for block in self.blocks_list:
            backcast, block_forecast = block(residual)
            forecast = forecast + block_forecast
            # subtract the backcast so the next block sees only the residual
            residual = residual - backcast

        return residual, forecast


class NBEATS(nn.Module):
    '''
    N-BEATS
    -------
    refer: 
    1. _TrendGenerator, _SeasonalityGenerator, _Block, _Stack, _NBEATSModule ****
    in darts-master\darts\models\forecasting\nbeats.py

    Flattens the multivariate input into one univariate series, runs it through
    residual-chained stacks, and reshapes the summed forecast back to
    (batch, n_out, target_size).
    '''
    def __init__(self, 
                 n_in: int = 8,
                 n_out: int = 1,
                 input_size: int = 3,
                 target_size: int = 3,
                 generic_architecture: bool = True,
                 num_stacks: int = 20, # 30
                 num_blocks: int = 1,
                 num_layers: int = 4,
                 layer_widths: Union[int, List[int]] = 128, # 256
                 expansion_coefficient_dim: int = 5,
                 trend_polynomial_degree: int = 2,
                 ):
        super(NBEATS, self).__init__()
        # the interpretable architecture always uses exactly two stacks (trend + seasonality)
        if not generic_architecture:
            num_stacks = 2
            self.num_stacks = 2

        # Normalize layer widths to one entry per stack. (Previously a
        # list-valued layer_widths left self.layer_widths unset, crashing the
        # generic branch, and the int default crashed the interpretable branch
        # at layer_widths[0].)
        if isinstance(layer_widths, int):
            layer_widths = [layer_widths] * num_stacks
        self.layer_widths = list(layer_widths)

        self.input_size = input_size
        self.target_size = target_size
        self.n_in_multi = n_in * input_size   # flattened univariate input length
        self.n_out = n_out
        self.target_length = n_out * target_size # 'fusing2'

        if generic_architecture:
            self.stacks_list = [
                NbeatsStack(num_blocks,
                       num_layers,
                       self.layer_widths[i],
                       expansion_coefficient_dim,
                       self.n_in_multi,
                       self.target_length,
                       _GType.GENERIC)
                for i in range(num_stacks)
            ]
        else:
            trend_stack = NbeatsStack(num_blocks,
                                 num_layers,
                                 self.layer_widths[0],
                                 trend_polynomial_degree + 1,
                                 self.n_in_multi,
                                 self.target_length,
                                 _GType.TREND)
            seasonality_stack = NbeatsStack(num_blocks,
                                       num_layers,
                                       self.layer_widths[1],
                                       -1,
                                       self.n_in_multi,
                                       self.target_length,
                                       _GType.SEASONALITY)
            self.stacks_list = [trend_stack, seasonality_stack]

        self.stacks = nn.ModuleList(self.stacks_list)

        # setting the last backcast "branch" to be not trainable (without next block/stack, it doesn't need to be
        # backpropagated). Removing this lines would cause logtensorboard to crash, since no gradient is stored
        # on this params (the last block backcast is not part of the final output of the net).
        self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)
        self.stacks_list[-1].blocks[-1].backcast_g.requires_grad_(False)

    def forward(self, x):
        # flatten (B, n_in, input_size) -> (B, n_in * input_size): the model is univariate,
        # so x1, y1, a1, x2, y2, a2, ... become one interleaved series
        x = torch.reshape(x, (x.shape[0], self.n_in_multi, 1))
        x = x.squeeze(dim=2)

        # one forecast accumulator of length target_length
        y = torch.zeros(x.shape[0], 
                        self.target_length,
                        device=x.device, 
                        dtype=x.dtype)

        for stack in self.stacks:
            # compute stack output
            stack_residual, stack_forecast = stack(x)

            # add stack forecast to final output
            y = y + stack_forecast

            # set current stack residual as input for next stack
            x = stack_residual

        # reshape the flat forecast back to (B, n_out, target_size)
        y = y.view(y.shape[0], self.n_out, self.target_size)

        return y


##################
##### N-HITS #####
##################
from functools import partial

# Cell
class IdentityBasis(nn.Module):
    """Identity basis for N-HiTS: the first backcast_size thetas are returned
    as the backcast, the remaining thetas are forecast 'knots' interpolated up
    to forecast_size."""

    def __init__(self, backcast_size: int, forecast_size: int, interpolation_mode: str):
        super(IdentityBasis, self).__init__()
        assert (interpolation_mode in ['linear','nearest']) or ('cubic' in interpolation_mode)
        self.forecast_size = forecast_size
        self.backcast_size = backcast_size
        self.interpolation_mode = interpolation_mode

    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        backcast = theta[:, :self.backcast_size]
        knots = theta[:, self.backcast_size:]
        mode = self.interpolation_mode

        if mode == 'nearest':
            forecast = F.interpolate(knots.unsqueeze(1), size=self.forecast_size,
                                     mode=mode).squeeze(1)
        elif mode == 'linear':
            forecast = F.interpolate(knots.unsqueeze(1), size=self.forecast_size,
                                     mode=mode, align_corners=True).squeeze(1)
        else:  # any 'cubic' variant: bicubic interpolation, chunked over the batch
            chunk = len(backcast)
            knots4d = knots[:, None, None, :]
            forecast = torch.zeros((len(knots), self.forecast_size)).to(knots.device)
            for start in range(0, len(knots4d), chunk):
                piece = F.interpolate(knots4d[start:start + chunk], size=self.forecast_size,
                                      mode='bicubic', align_corners=True)
                forecast[start:start + chunk] += piece[:, 0, 0, :]

        return backcast, forecast

# Cell
def _init_weights(module, initialization):
    if type(module) == nn.Linear:
        if initialization == 'orthogonal':
            nn.init.orthogonal_(module.weight)
        elif initialization == 'he_uniform':
            nn.init.kaiming_uniform_(module.weight)
        elif initialization == 'he_normal':
            nn.init.kaiming_normal_(module.weight)
        elif initialization == 'glorot_uniform':
            nn.init.xavier_uniform_(module.weight)
        elif initialization == 'glorot_normal':
            nn.init.xavier_normal_(module.weight)
        elif initialization == 'lecun_normal':
            pass #nn.init.normal_(module.weight, 0.0, std=1/np.sqrt(module.weight.numel()))
        else:
            assert 1<0, f'Initialization {initialization} not found'

# Cell
# Activation-function class names accepted by NhitsBlock; each is resolved
# with getattr(nn, activation) when the block is built.
ACTIVATIONS = ['ReLU',
               'Softplus',
               'Tanh',
               'SELU',
               'LeakyReLU',
               'PReLU',
               'Sigmoid']


class NhitsBlock(nn.Module):
    '''
    N-HiTS block: an N-BEATS-style block with input pooling and basis
    interpolation. The pooled input runs through an MLP that emits n_theta
    expansion coefficients; `basis` splits them into a backcast and an
    interpolated forecast.
    '''
    def __init__(self,
                 num_layers: int, # n_layers
                 n_mlp_units: list, # hidden widths of the MLP, one entry per layer
                 n_theta: int, # number of expansion coefficients emitted
                 basis: nn.Module, # maps theta -> (backcast, forecast), e.g. IdentityBasis
                 n_in: int, # n_time_in
                 n_out: int, # n_time_out
                 n_pool_kernel_size: int,
                 pooling_mode: str,
                 activation: str):
        super(NhitsBlock, self).__init__()

        self.num_layers = num_layers
        self.n_out = n_out
        self.n_pool_kernel_size = n_pool_kernel_size

        assert (pooling_mode in ['max','average'])
        # input length after pooling (matches ceil_mode=True on the pool layers)
        n_time_in_pooled = int(np.ceil(n_in/n_pool_kernel_size))

        # no exogenous / static features in this variant
        n_x = n_s_hidden = 0
        n_mlp_units = [n_time_in_pooled + (n_in+n_out)*n_x + n_s_hidden] + n_mlp_units

        if pooling_mode == 'max':
            self.pooling_layer = nn.MaxPool1d(kernel_size=self.n_pool_kernel_size,
                                              stride=self.n_pool_kernel_size, ceil_mode=True)
        elif pooling_mode == 'average':
            self.pooling_layer = nn.AvgPool1d(kernel_size=self.n_pool_kernel_size,
                                              stride=self.n_pool_kernel_size, ceil_mode=True)

        assert activation in ACTIVATIONS, f'{activation} is not in {ACTIVATIONS}'
        activ = getattr(nn, activation)()

        # MLP: num_layers hidden Linear+activation pairs, then a linear head
        # emitting the n_theta expansion coefficients
        hidden_layers = []
        for i in range(num_layers):
            hidden_layers.append(nn.Linear(in_features=n_mlp_units[i], out_features=n_mlp_units[i+1]))
            hidden_layers.append(activ)

        output_layer = [nn.Linear(in_features=n_mlp_units[-1], out_features=n_theta)]
        self.layers = nn.Sequential(*(hidden_layers + output_layer))

        self.basis = basis

    def forward(self, x):
        # pool over time: (B, T) -> (B, 1, T) -> pooled -> (B, T')
        x = self.pooling_layer(x.unsqueeze(1)).squeeze(1)

        # expansion coefficients
        theta = self.layers(x)

        # basis splits theta into the backcast and the interpolated forecast
        backcast, forecast = self.basis(theta)

        return backcast, forecast


class NhitsStack(nn.Module):
    """A stack of N-HiTS blocks applied residually.

    Every block consumes the running residual, contributes its forecast to
    the stack forecast, and its backcast is subtracted from the residual
    before the next block runs.
    """

    def __init__(self,
                 num_blocks: int,
                 num_layers: int,
                 n_mlp_units: list,
                 n_in: int, # n_time_in
                 n_out: int, # n_time_out
                 n_pool_kernel_size: int,
                 n_freq_downsample: int,
                 pooling_mode: str,
                 interpolation_mode: str,
                 activation: str,
                 initialization: str,
                 shared_weights: bool,
                 ):
        super(NhitsStack, self).__init__()

        self.n_in = n_in
        self.n_out = n_out

        # Weight-initialization scheme applied to every block's MLP layers.
        init_function = partial(_init_weights, initialization=initialization)

        blocks = []
        for block_id in range(num_blocks):
            if shared_weights and block_id > 0:
                # Reuse the previous block so all blocks share parameters.
                block = blocks[-1]
            else:
                # Coefficient count: full backcast plus a (possibly
                # frequency-downsampled) forecast of at least one step.
                n_theta = n_in + max(n_out // n_freq_downsample, 1)
                basis = IdentityBasis(backcast_size=n_in,
                                      forecast_size=n_out,
                                      interpolation_mode=interpolation_mode)

                block = NhitsBlock(num_layers,
                                   n_mlp_units,
                                   n_theta,
                                   basis,
                                   n_in,
                                   n_out,
                                   n_pool_kernel_size,
                                   pooling_mode,
                                   activation)
            # (Re-)initialize the block's MLP weights with the chosen scheme.
            block.layers.apply(init_function)
            blocks.append(block)

        self.blocks_list = blocks
        self.blocks = nn.ModuleList(self.blocks_list)

    def forward(self, x):
        """Return (stack_residual, stack_forecast) for input x of shape (batch, n_in)."""
        stack_forecast = torch.zeros(x.shape[0],
                                     self.n_out,
                                     device=x.device,
                                     dtype=x.dtype)

        for block in self.blocks_list:
            x_hat, y_hat = block(x)
            stack_forecast = stack_forecast + y_hat  # accumulate block forecasts
            x = x - x_hat                            # residual feeds the next block

        return x, stack_forecast


class NHITS(nn.Module):
    r'''
    N-HITS
    ------
    Neural Hierarchical Interpolation for Time Series. The multivariate input
    is flattened into one univariate sequence, passed through ``num_stacks``
    residual NhitsStack modules, and the per-stack forecasts are summed into
    the final forecast.

    refer: _init_weights
    1. _IdentityBasis, _NHITSBlock, create_stack, _NHITS ****
    in neuralforecast-main\neuralforecast\models\nhits\nhits.py
    '''
    def __init__(self,
                 n_in: int = 8, # n_time_in: input window length
                 n_out: int = 1, # n_time_out: forecast horizon
                 input_size: int = 3, # number of input series (incl. covariates)
                 target_size: int = 3, # number of target series
                 num_stacks: int = 20,
                 num_blocks: int = 1,
                 num_layers: int = 4,
                 n_mlp_units: list = None, # per-layer MLP widths; defaults to [200, 300, 200, 100]
                 n_pool_kernel_size: int = 1,
                 n_freq_downsample: int = 1,
                 pooling_mode: str = 'max',
                 interpolation_mode: str = 'nearest', # 'linear', 'nearest', 'cubic'
                 activation: str = 'ReLU',
                 initialization: str = 'lecun_normal',
                 shared_weights: bool = False,
                 ):
        super(NHITS, self).__init__()

        # Fix: the previous mutable default argument ([200,300,200,100]) was a
        # single shared list object; resolve the default inside the body.
        if n_mlp_units is None:
            n_mlp_units = [200, 300, 200, 100]
        self.n_mlp_units = [n_mlp_units] * num_stacks

        self.input_size = input_size
        self.target_size = target_size
        self.n_in_multi = n_in * input_size # Insample size = n_time_in * output_size
        self.n_out = n_out # Forecast horizon.
        self.target_length = n_out * target_size # 'fusing2'

        self.stacks_list = [
            NhitsStack(num_blocks,
                       num_layers,
                       self.n_mlp_units[i],
                       self.n_in_multi,
                       self.target_length,
                       n_pool_kernel_size,
                       n_freq_downsample,
                       pooling_mode,
                       interpolation_mode,
                       activation,
                       initialization,
                       shared_weights,
                       )
            for i in range(num_stacks)
        ]

        self.stacks = nn.ModuleList(self.stacks_list)

    def forward(self, x):
        """Forecast from a flattened multivariate window.

        Args:
            x: tensor whose leading dim is batch and whose remaining elements
               flatten to ``n_in * input_size`` features per sample.

        Returns:
            Tensor of shape (batch, n_out, target_size).
        """
        # Flatten the multivariate input into one univariate sequence
        # (x1, y1, a1, x2, y2, a2, ...).
        x = torch.reshape(x, (x.shape[0], self.n_in_multi, 1))
        x = x.squeeze(dim=2)

        # Accumulator for the sum of all stack forecasts.
        y = torch.zeros(x.shape[0],
                        self.target_length,
                        device=x.device,
                        dtype=x.dtype)

        for stack in self.stacks:
            stack_residual, stack_forecast = stack(x)

            # add stack forecast to final output
            y = y + stack_forecast

            # set current stack residual as input for next stack
            x = stack_residual

        # Reshape the flat forecast back to (batch, horizon, targets).
        y = y.view(y.shape[0], self.n_out, self.target_size)

        return y


###############
##### RNN #####
###############

class _RNN(nn.Module):
    r'''
    Recurrent Neural Networks
    -------------------------
    A torch.nn recurrent module ('RNN', 'GRU' or 'LSTM') followed by a linear
    per-timestep decoder and a TS2Vec-style TSDecoder mapping the n_in input
    steps to the n_out forecast steps.

    refer:
    1. _RNNModule *****
    in darts-master\darts\models\forecasting\rnn_model.py
    2. RNNModel **
    in gluonts\model\deep_factor\RNNModel.py
    '''
    def __init__(self, name='LSTM', n_in=8, n_out=1, input_size=3, target_size=3, hidden_layer_size=128, num_layers=3, bidirectional=True, dropout=0.):
        super(_RNN, self).__init__()
        self.input_size = input_size
        # Bidirectional RNNs double the feature dimension of the output.
        self.D = 2 if bidirectional else 1

        # `name` selects the recurrent module class from torch.nn.
        self.rnn = getattr(nn, name)(input_size=input_size,
                            hidden_size=hidden_layer_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            batch_first=True,
                            dropout=dropout)
        # Per-timestep projection from hidden state to the target dimension.
        self.decoder = nn.Linear(self.D*hidden_layer_size, target_size)
        ### refer TS2Vec: maps the n_in decoded steps to the n_out horizon.
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

    def forward(self, input_seq):
        """input_seq: (batch, timesteps, input_size) -> TSDecoder output."""
        self.rnn.flatten_parameters()
        rnn_out, last_hidden_state = self.rnn(input_seq)
        predictions = self.decoder(rnn_out)

        return self.tsdecoder(predictions)

###################################################
# use_cuda = torch.cuda.is_available()

# Cell
class LSTMCell(nn.Module):
    """A plain LSTM cell with explicitly managed weight matrices.

    All four gate pre-activations are computed jointly as one
    (4 * hidden_size) projection of the input plus one of the previous
    hidden state, then split into input/forget/cell/output gates.
    """

    def __init__(self, input_size, hidden_size, dropout=0.):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Weights initialized from a standard normal (no fan-in scaling).
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size))
        self.dropout = dropout

    def forward(self, inputs, hidden):
        """One step: inputs (batch, input_size), hidden = (h, c) with a leading dim to squeeze."""
        prev_h, prev_c = hidden[0].squeeze(0), hidden[1].squeeze(0)

        # Joint pre-activation: x W_ih^T + b_ih + h W_hh^T + b_hh.
        pre_act = torch.matmul(inputs, self.weight_ih.t()) + self.bias_ih
        pre_act = pre_act + torch.matmul(prev_h, self.weight_hh.t()) + self.bias_hh
        i_gate, f_gate, g_gate, o_gate = pre_act.chunk(4, 1)

        i_gate = i_gate.sigmoid()
        f_gate = f_gate.sigmoid()
        g_gate = g_gate.tanh()
        o_gate = o_gate.sigmoid()

        next_c = (f_gate * prev_c) + (i_gate * g_gate)
        next_h = o_gate * next_c.tanh()

        return next_h, (next_h, next_c)

# Cell
class ResLSTMCell(nn.Module):
    """Residual LSTM cell.

    The input/forget/output gates see the input, the previous hidden state
    AND the previous cell state (peephole-style). The new hidden state adds
    a residual connection from the input, projected by `weight_ir` when the
    input and hidden sizes differ.
    """

    def __init__(self, input_size, hidden_size, dropout=0.):
        super(ResLSTMCell, self).__init__()
        # Sizes kept as buffers so they travel with state_dict / .to(device).
        self.register_buffer('input_size', torch.Tensor([input_size]))
        self.register_buffer('hidden_size', torch.Tensor([hidden_size]))
        self.weight_ii = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_ic = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        self.bias_ii = nn.Parameter(torch.randn(3 * hidden_size))
        self.bias_ic = nn.Parameter(torch.randn(3 * hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
        self.weight_hh = nn.Parameter(torch.randn(1 * hidden_size, hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(1 * hidden_size))
        self.weight_ir = nn.Parameter(torch.randn(hidden_size, input_size))
        self.dropout = dropout

    def forward(self, inputs, hidden):
        """One step: inputs (batch, input_size), hidden = (h, c) with a leading dim to squeeze."""
        prev_h, prev_c = hidden[0].squeeze(0), hidden[1].squeeze(0)

        # Input/forget/output gates from x, h and c jointly.
        gate_pre = torch.matmul(inputs, self.weight_ii.t()) + self.bias_ii
        gate_pre = gate_pre + torch.matmul(prev_h, self.weight_ih.t()) + self.bias_ih
        gate_pre = gate_pre + torch.matmul(prev_c, self.weight_ic.t()) + self.bias_ic
        i_gate, f_gate, o_gate = gate_pre.chunk(3, 1)

        # Candidate cell value depends only on the previous hidden state.
        cand = torch.matmul(prev_h, self.weight_hh.t()) + self.bias_hh

        i_gate = i_gate.sigmoid()
        f_gate = f_gate.sigmoid()
        cand = cand.tanh()
        o_gate = o_gate.sigmoid()

        next_c = (f_gate * prev_c) + (i_gate * cand)
        res = next_c.tanh()

        # Residual connection from the input; project it when sizes differ.
        if self.input_size == self.hidden_size:
            next_h = o_gate * (res + inputs)
        else:
            next_h = o_gate * (res + torch.matmul(inputs, self.weight_ir.t()))
        return next_h, (next_h, next_c)

# Cell
class ResLSTMLayer(nn.Module):
    """Unrolls a ResLSTMCell over the time dimension (dim 0) of the input."""

    def __init__(self, input_size, hidden_size, dropout=0.):
        super(ResLSTMLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Bug fix: the dropout argument was hard-coded to 0. and silently
        # discarded; forward it to the cell. (ResLSTMCell currently only
        # stores it, so observed behavior is unchanged.)
        self.cell = ResLSTMCell(input_size, hidden_size, dropout=dropout)

    def forward(self, inputs, hidden):
        """inputs: (seq_len, batch, input_size). Returns (outputs, final hidden)."""
        steps = inputs.unbind(0)
        outputs = []
        for step in steps:
            out, hidden = self.cell(step, hidden)
            outputs.append(out)
        outputs = torch.stack(outputs)
        return outputs, hidden

# Cell
class AttentiveLSTMLayer(nn.Module):
    """LSTM layer with additive attention over the whole input window.

    At every step an attention score is computed for each timestep from the
    inputs and the current (h, c) state; the softmax-weighted context vector
    is fed to the LSTM cell instead of the raw input at that step.
    """

    def __init__(self, input_size, hidden_size, dropout=0.0):
        super(AttentiveLSTMLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        attention_hsize = hidden_size
        self.attention_hsize = attention_hsize

        self.cell = LSTMCell(input_size, hidden_size)
        # Scores each timestep from [input, h, c] -> scalar.
        self.attn_layer = nn.Sequential(nn.Linear(2 * hidden_size + input_size, attention_hsize),
                                        nn.Tanh(),
                                        nn.Linear(attention_hsize, 1))
        self.softmax = nn.Softmax(dim=0)  # normalize over the time dimension
        self.dropout = dropout

    def forward(self, inputs, hidden):
        """inputs: (seq_len, batch, input_size). Returns (outputs, final hidden)."""
        # Bug fix: the original rebound `inputs` to a tuple via unbind(0) and
        # then passed that tuple to torch.cat()/.permute(), which raises a
        # TypeError; keep the tensor and only take seq_len from its shape.
        seq_len = inputs.size(0)
        outputs = []

        for t in range(seq_len):
            # attention on windows
            hx, cx = (tensor.squeeze(0) for tensor in hidden)
            hx_rep = hx.repeat(seq_len, 1, 1)
            cx_rep = cx.repeat(seq_len, 1, 1)
            x = torch.cat((inputs, hx_rep, cx_rep), dim=-1)
            scores = self.attn_layer(x)
            beta = self.softmax(scores)
            # Weighted sum over time -> one context vector per batch element.
            context = torch.bmm(beta.permute(1, 2, 0),
                                inputs.permute(1, 0, 2)).squeeze(1)
            out, hidden = self.cell(context, hidden)
            outputs += [out]
        outputs = torch.stack(outputs)
        return outputs, hidden

# Cell
# class DRNN(nn.Module):
class DilRNN(nn.Module):
    '''
    Dilated Recurrent Neural Networks
    ---------------------------------
    refer: 
    1. DRNN ***
    in pytorch-dilated-rnn-master\drnn.py
    2. DRNN *****
    in n-hits-main\src\models\components\drnn.py
    '''
    def __init__(self, n_input, n_hidden, n_layers, dilations, dropout=0, cell_type='GRU', batch_first=True):
        # dropout=0, batch_first=True, bidirectional=False, cell_type='GRU'):
        super(DilRNN, self).__init__()

        assert n_layers==len(dilations) # [1 for i in range(n_layers)]
        self.dilations = dilations # [2 ** i for i in range(n_layers)]
        self.cell_type = cell_type
        self.batch_first = batch_first

        layers = []
        if self.cell_type == "GRU":
            cell = nn.GRU
        elif self.cell_type == "RNN":
            cell = nn.RNN
        elif self.cell_type == "LSTM":
            cell = nn.LSTM
        elif self.cell_type == "ResLSTM": # 不行
            cell = ResLSTMLayer
        elif self.cell_type == "AttentiveLSTM": # 也不行。。。
            cell = AttentiveLSTMLayer
        else:
            raise NotImplementedError

        for i in range(n_layers):
            if i == 0:
                c = cell(n_input, n_hidden, dropout=dropout)
            else:
                c = cell(n_hidden, n_hidden, dropout=dropout)
            layers.append(c)
        self.cells = nn.Sequential(*layers)

    def forward(self, inputs, hidden=None):
        if self.batch_first:
            inputs = inputs.transpose(0, 1)
        outputs = []
        for i, (cell, dilation) in enumerate(zip(self.cells, self.dilations)):
            if hidden is None:
                inputs, _ = self.drnn_layer(cell, inputs, dilation)
            else:
                inputs, hidden[i] = self.drnn_layer(cell, inputs, dilation, hidden[i])

            outputs.append(inputs[-dilation:])

        if self.batch_first:
            inputs = inputs.transpose(0, 1)
        return inputs, outputs

    def drnn_layer(self, cell, inputs, rate, hidden=None):
        n_steps = len(inputs)
        batch_size = inputs[0].size(0)
        hidden_size = cell.hidden_size

        inputs, dilated_steps = self._pad_inputs(inputs, n_steps, rate)
        dilated_inputs = self._prepare_inputs(inputs, rate)

        if hidden is None:
            dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size)
        else:
            hidden = self._prepare_inputs(hidden, rate)
            dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size,
                                                       hidden=hidden)

        splitted_outputs = self._split_outputs(dilated_outputs, rate)
        outputs = self._unpad_outputs(splitted_outputs, n_steps)

        return outputs, hidden

    def _apply_cell(self, dilated_inputs, cell, batch_size, rate, hidden_size, hidden=None):
        if hidden is None:
            hidden = torch.zeros(batch_size * rate, hidden_size,
                                 dtype=dilated_inputs.dtype,
                                 device=dilated_inputs.device)
            hidden = hidden.unsqueeze(0)

            if self.cell_type in ['LSTM', 'ResLSTM', 'AttentiveLSTM']:
                hidden = (hidden, hidden)

        dilated_outputs, hidden = cell(dilated_inputs, hidden) # compatibility hack

        return dilated_outputs, hidden

    def _unpad_outputs(self, splitted_outputs, n_steps):
        return splitted_outputs[:n_steps]

    def _split_outputs(self, dilated_outputs, rate):
        batchsize = dilated_outputs.size(1) // rate

        blocks = [dilated_outputs[:, i * batchsize: (i + 1) * batchsize, :] for i in range(rate)]

        interleaved = torch.stack((blocks)).transpose(1, 0).contiguous()
        interleaved = interleaved.view(dilated_outputs.size(0) * rate,
                                       batchsize,
                                       dilated_outputs.size(2))
        return interleaved

    def _pad_inputs(self, inputs, n_steps, rate):
        iseven = (n_steps % rate) == 0

        if not iseven:
            dilated_steps = n_steps // rate + 1

            zeros_ = torch.zeros(dilated_steps * rate - inputs.size(0),
                                 inputs.size(1),
                                 inputs.size(2),
                                 dtype=inputs.dtype,
                                 device=inputs.device)
            inputs = torch.cat((inputs, zeros_))
        else:
            dilated_steps = n_steps // rate

        return inputs, dilated_steps

    def _prepare_inputs(self, inputs, rate):
        dilated_inputs = torch.cat([inputs[j::rate, :, :] for j in range(rate)], 1)
        return dilated_inputs


class RNN(nn.Module):
    r'''
    Recurrent Neural Networks
    -------------------------
    Wraps either a dilated RNN (name='DilRNN') or a torch.nn recurrent module
    ('RNN', 'LSTM', 'GRU'), followed by a linear decoder and a TS2Vec-style
    TSDecoder mapping the n_in input steps to the n_out forecast steps.

    refer:
    1. _RNNModule *****
    in darts-master\darts\models\forecasting\rnn_model.py
    2. RNNModel **
    in gluonts\model\deep_factor\RNNModel.py
    '''
    def __init__(self, name='LSTM', n_in=8, n_out=1, input_size=3, target_size=3, hidden_layer_size=128, n_layers=3, bidirectional=False, dropout=0, cell_type='GRU'):
        super(RNN, self).__init__()
        self.input_size = input_size
        # DilRNN is unidirectional, so the feature dimension is doubled only
        # for bidirectional torch.nn modules.
        self.D = 2 if bidirectional and name != 'DilRNN' else 1
        self.name = name

        if name == 'DilRNN':
            self.rnn = DilRNN(n_input=input_size,
                              n_hidden=hidden_layer_size,
                              n_layers=n_layers,
                              # exponential dilation schedule: 1, 2, 4, ...
                              dilations=[2 ** i for i in range(n_layers)],
                              batch_first=True,
                              dropout=dropout,
                              cell_type=cell_type)
        else:
            self.rnn = getattr(nn, name)(input_size=input_size,
                                         hidden_size=hidden_layer_size,
                                         num_layers=n_layers,
                                         bidirectional=bidirectional,
                                         batch_first=True,
                                         dropout=dropout)
        # NOTE(review): decodes hidden state back to input_size (not
        # target_size, unlike _RNN); the TSDecoder below is constructed with
        # this input_size, so the widths agree — confirm intent nonetheless.
        self.decoder = nn.Linear(self.D*hidden_layer_size, input_size)
        ### refer TS2Vec
        self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'

    def forward(self, input_seq): # input_seq: batch_size x timesteps x input_size
        # Only cuDNN-backed torch.nn modules need their weights flattened.
        if self.name in ['RNN', 'LSTM', 'GRU']:
            self.rnn.flatten_parameters()
        rnn_out, last_hidden_state = self.rnn(input_seq)
        predictions = self.decoder(rnn_out)

        return self.tsdecoder(predictions)



###############
##### ESN #####
###############

import re
import torch.sparse
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

## reservoir
"""
This examples is not intended to be optimized. Its purpose is to show how to handle
big datasets with multiple sequences. The accuracy should be around 10%.
"""

def apply_permutation(tensor, permutation, dim=1):
    # type: (Tensor, Tensor, int) -> Tensor
    """Return `tensor` with its slices along `dim` reordered by index tensor `permutation`."""
    reordered = tensor.index_select(dim, permutation)
    return reordered


class Reservoir(nn.Module):
    """Multi-layer Echo State Network (ESN) reservoir.

    Weights are registered per layer under torch.nn.RNN-style names
    (``weight_ih_l{k}``, ``weight_hh_l{k}``, ``bias_ih_l{k}``) and randomized
    by :meth:`reset_parameters`, which also rescales each recurrent matrix to
    the requested spectral radius. Time unrolling is delegated to the
    :func:`AutogradReservoir` closure in :meth:`forward`.
    """

    def __init__(self, mode, input_size, hidden_size, num_layers, leaking_rate,
                 spectral_radius, w_ih_scale,
                 density, bias=True, batch_first=False):
        super(Reservoir, self).__init__()
        self.mode = mode  # cell nonlinearity, e.g. 'RES_TANH' / 'RES_RELU' / 'RES_ID'
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.leaking_rate = leaking_rate  # leaky-integration rate handed to the cell
        self.spectral_radius = spectral_radius  # target spectral radius for each w_hh
        # assumes w_ih_scale is indexable with [0] scaling the input bias and
        # [1:] scaling the first layer's input weights — TODO confirm shape
        self.w_ih_scale = w_ih_scale
        self.density = density  # fraction of nonzero entries kept in w_hh (1 = dense)
        self.bias = bias
        self.batch_first = batch_first

        self._all_weights = []
        for layer in range(num_layers):
            # Layers beyond the first consume the previous layer's hidden state.
            layer_input_size = input_size if layer == 0 else hidden_size

            w_ih = nn.Parameter(torch.Tensor(hidden_size, layer_input_size))
            w_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
            b_ih = nn.Parameter(torch.Tensor(hidden_size))
            layer_params = (w_ih, w_hh, b_ih)

            param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
            if bias:
                param_names += ['bias_ih_l{}{}']
            param_names = [x.format(layer, '') for x in param_names]

            # Register each tensor under its torch.nn.RNN-style name.
            for name, param in zip(param_names, layer_params):
                setattr(self, name, param)
            self._all_weights.append(param_names)

        self.reset_parameters()

    def _apply(self, fn):
        # No flattened cuDNN weights to refresh; defer entirely to nn.Module.
        ret = super(Reservoir, self)._apply(fn)
        return ret

    def reset_parameters(self):
        """Randomize all weights and enforce the spectral-radius constraint.

        Mutates the tensors fetched via ``state_dict()`` (in-place for the
        ih/bias entries, by reassignment for hh) and loads the dict back.
        """
        weight_dict = self.state_dict()
        for key, value in weight_dict.items():
            if key == 'weight_ih_l0':
                nn.init.uniform_(value, -1, 1)
                value *= self.w_ih_scale[1:]
            elif re.fullmatch('weight_ih_l[^0]*', key):
                nn.init.uniform_(value, -1, 1)
            elif re.fullmatch('bias_ih_l[0-9]*', key):
                nn.init.uniform_(value, -1, 1)
                value *= self.w_ih_scale[0]
            elif re.fullmatch('weight_hh_l[0-9]*', key):
                # Draw a flat random vector, zero a (1 - density) fraction of
                # it, then rescale so the largest eigenvalue magnitude equals
                # `spectral_radius` (the classic echo-state condition).
                w_hh = torch.Tensor(self.hidden_size * self.hidden_size)
                w_hh.uniform_(-1, 1)
                if self.density < 1:
                    zero_weights = torch.randperm(
                        int(self.hidden_size * self.hidden_size))
                    zero_weights = zero_weights[
                                   :int(
                                       self.hidden_size * self.hidden_size * (
                                                   1 - self.density))]
                    w_hh[zero_weights] = 0
                w_hh = w_hh.view(self.hidden_size, self.hidden_size)
                # abs_eigs = (torch.eig(w_hh)[0] ** 2).sum(1).sqrt()
                abs_eigs = torch.abs(torch.linalg.eigvals(w_hh))#.sum(1).sqrt()   ** 2  LA.norm
                weight_dict[key] = w_hh * (self.spectral_radius / torch.max(abs_eigs))

        self.load_state_dict(weight_dict)

    def check_input(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> None
        """Validate the rank and feature size of `input` (packed or padded)."""
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
        """Return the hidden-state shape implied by the input's batch size."""
        if batch_sizes is not None:
            mini_batch = batch_sizes[0]
            mini_batch = int(mini_batch)
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        expected_hidden_size = (self.num_layers, mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
        # type: (Tensor, Tuple[int, int, int], str) -> None
        """Raise if the provided hidden state does not match the expected shape."""
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        """Combined validation of input and hidden state before a forward pass."""
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        """Reorder the batch dim of `hx` to match a PackedSequence's sort order."""
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    def forward(self, input, hx=None):
        """Run the reservoir over `input` (Tensor or PackedSequence).

        Returns (output, hidden), re-packing the output and un-permuting the
        hidden state when the input was packed.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            # Reservoir states start at zero and are not trained.
            hx = input.new_zeros(self.num_layers, max_batch_size,
                                 self.hidden_size, requires_grad=False)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        flat_weight = None

        self.check_forward_args(input, hx, batch_sizes)
        # Build the functional unrolling closure and run it with this
        # module's registered weights.
        func = AutogradReservoir(
            self.mode,
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            train=self.training,
            variable_length=is_packed,
            flat_weight=flat_weight,
            leaking_rate=self.leaking_rate
        )
        output, hidden = func(input, self.all_weights, hx, batch_sizes)
        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def extra_repr(self):
        """repr() fragment mirroring torch.nn.RNN's parameter summary."""
        s = '({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        s += ')'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        """Rebuild `_all_weights` name lists when unpickling older checkpoints."""
        super(Reservoir, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # Already name-based (new format): nothing to convert.
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        self._all_weights = []
        for layer in range(num_layers):
            weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}']
            weights = [x.format(layer) for x in weights]
            if self.bias:
                self._all_weights += [weights]
            else:
                self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        # Resolve the stored names to the live parameter tensors, per layer.
        return [[getattr(self, weight) for weight in weights] for weights in
                self._all_weights]


def AutogradReservoir(mode, input_size, hidden_size, num_layers=1,
                      batch_first=False, train=True,
                      batch_sizes=None, variable_length=False, flat_weight=None,
                      leaking_rate=1):
    """Build the functional forward pass for a Reservoir.

    Selects the reservoir cell for `mode`, wraps it in a (variable-length or
    fixed-length) time-unrolling closure, and stacks it `num_layers` deep.

    Returns:
        A ``forward(input, weight, hidden, batch_sizes)`` closure producing
        ``(output, next_hidden)``.

    Raises:
        ValueError: if `mode` is not 'RES_TANH', 'RES_RELU' or 'RES_ID'.
    """
    if mode == 'RES_TANH':
        cell = ResTanhCell
    elif mode == 'RES_RELU':
        cell = ResReLUCell
    elif mode == 'RES_ID':
        cell = ResIdCell
    else:
        # Bug fix: an unknown mode previously left `cell` unbound and
        # surfaced later as a NameError; fail fast with a clear message.
        raise ValueError("unknown reservoir mode: {!r}".format(mode))

    if variable_length:
        layer = (VariableRecurrent(cell, leaking_rate),)
    else:
        layer = (Recurrent(cell, leaking_rate),)

    func = StackedRNN(layer,
                      num_layers,
                      False,
                      train=train)

    def forward(input, weight, hidden, batch_sizes):
        # The unrolling closures operate time-major; transpose on entry/exit.
        if batch_first and batch_sizes is None:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight, batch_sizes)

        if batch_first and not variable_length:
            output = output.transpose(0, 1)

        return output, nexth

    return forward


def Recurrent(inner, leaking_rate):
    """Wrap cell `inner` in a closure that unrolls it over time (dim 0)."""
    def forward(input, hidden, weight, batch_sizes):
        step_outputs = []
        for t in range(input.size(0)):
            hidden = inner(input[t], hidden, leaking_rate, *weight)
            # hack to handle LSTM-style cells returning (h, c): emit only h.
            step_outputs.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        stacked = torch.cat(step_outputs, 0).view(input.size(0), *step_outputs[0].size())

        return hidden, stacked

    return forward


def VariableRecurrent(inner, leaking_rate):
    """Sequence driver for packed (variable-length) batches.

    ``batch_sizes`` gives the number of still-active sequences per timestep
    (non-increasing, PackedSequence convention). As sequences finish, their
    hidden states are parked and finally re-assembled in original order.
    """
    def forward(input, hidden, weight, batch_sizes):
        outputs = []
        finished = []  # hidden states of sequences that have already ended
        offset = 0
        prev_bs = batch_sizes[0]
        is_flat = not isinstance(hidden, tuple)
        if is_flat:
            hidden = (hidden,)

        for bs in batch_sizes:
            step_in = input[offset:offset + bs]
            offset += bs

            shrink = prev_bs - bs
            if shrink > 0:
                # The trailing `shrink` sequences just ended; park their state.
                finished.append(tuple(h[-shrink:] for h in hidden))
                hidden = tuple(h[:-shrink] for h in hidden)
            prev_bs = bs

            if is_flat:
                hidden = (inner(step_in, hidden[0], leaking_rate, *weight),)
            else:
                hidden = inner(step_in, hidden, leaking_rate, *weight)

            outputs.append(hidden[0])

        finished.append(hidden)
        finished.reverse()  # longest sequences first, matching batch order

        hidden = tuple(torch.cat(h, 0) for h in zip(*finished))
        assert hidden[0].size(0) == batch_sizes[0]
        if is_flat:
            hidden = hidden[0]

        return hidden, torch.cat(outputs, 0)

    return forward


def StackedRNN(inners, num_layers, lstm=False, train=True):
    """Stack ``num_layers`` recurrent layers, one entry of ``inners`` per
    direction. Each layer feeds its (direction-concatenated) output to the
    next; the final output concatenates every layer's output on the last
    axis, and the final hiddens are stacked along a new leading axis."""
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight, batch_sizes):
        assert (len(weight) == total_layers)
        final_hiddens = []
        per_layer = []

        layer_input = input
        for layer in range(num_layers):
            dir_outputs = []
            for direction, step_fn in enumerate(inners):
                idx = layer * num_directions + direction
                hy, out = step_fn(layer_input, hidden[idx], weight[idx],
                                  batch_sizes)
                final_hiddens.append(hy)
                dir_outputs.append(out)

            # Concatenate directions on the feature axis for the next layer.
            layer_input = torch.cat(dir_outputs, layer_input.dim() - 1)
            per_layer.append(layer_input)

        stacked_out = torch.cat(per_layer, -1)
        stacked_hidden = torch.cat(final_hiddens, 0).view(
            total_layers, *final_hiddens[0].size())

        return stacked_hidden, stacked_out

    return forward


def ResTanhCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None):
    """Leaky-integrator reservoir cell with tanh nonlinearity."""
    pre_activation = F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh)
    candidate = torch.tanh(pre_activation)
    # Leaky update: blend previous state with the new candidate.
    return (1 - leaking_rate) * hidden + leaking_rate * candidate


def ResReLUCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None):
    """Leaky-integrator reservoir cell with ReLU nonlinearity."""
    pre_activation = F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh)
    candidate = F.relu(pre_activation)
    # Leaky update: blend previous state with the new candidate.
    return (1 - leaking_rate) * hidden + leaking_rate * candidate


def ResIdCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None):
    """Leaky-integrator reservoir cell with identity (linear) activation."""
    candidate = F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh)
    # Leaky update: blend previous state with the new candidate.
    return (1 - leaking_rate) * hidden + leaking_rate * candidate


## utils

def prepare_target(target, seq_lengths, washout, batch_first=False):
    """ Preprocess target for offline training.

    Args:
        target (seq_len, batch, output_size): tensor containing
            the features of the target sequence.
        seq_lengths: list of lengths of each sequence in the batch.
        washout: number of initial timesteps during which output of the
            reservoir is not forwarded to the readout. One value per sample.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``

    Returns:
        tensor containing the features of the batch's sequences rolled out along
        one axis, minus the washouts and the padded values.
    """

    if batch_first:
        target = target.transpose(0, 1)

    out_dim = target.size(2)
    # Total number of rows once each sequence loses its washout prefix.
    total_rows = sum(int(l) - int(w) for l, w in zip(seq_lengths, washout))
    flat = torch.zeros(total_rows, out_dim, device=target.device)

    row = 0
    for s in range(target.size(1)):
        keep = seq_lengths[s] - washout[s]
        flat[row:row + keep, :] = target[washout[s]:seq_lengths[s], s, :]
        row += keep

    return flat


def washout_tensor(tensor, washout, seq_lengths, bidirectional=False, batch_first=False):
    """Drop the first ``washout[b]`` timesteps of every sequence in a padded
    batch, shifting the remainder to the front and zero-padding the tail.

    Args:
        tensor: (seq, batch, feat) tensor, or (batch, seq, feat) if
            ``batch_first``.
        washout: per-sample number of initial timesteps to discard.
        seq_lengths: per-sample valid lengths (list or 1-D tensor).
        bidirectional: additionally trim the tail for the reverse direction.
        batch_first: layout of ``tensor`` (output is always time-major).

    Returns:
        (trimmed tensor cut to the new max length, updated seq_lengths).
    """
    # BUGFIX: transpose() returns a view, so the in-place writes below used
    # to leak into the caller's tensor when batch_first=True; clone in both
    # branches so the caller's data is never mutated.
    tensor = tensor.transpose(0, 1).clone() if batch_first else tensor.clone()
    # Copy lengths too: they are modified in place below.
    if type(seq_lengths) == list:
        seq_lengths = seq_lengths.copy()
    if type(seq_lengths) == torch.Tensor:
        seq_lengths = seq_lengths.clone()

    for b in range(tensor.size(1)):
        if washout[b] > 0:
            # Shift the post-washout slice to the front, zero the freed tail.
            tmp = tensor[washout[b]:seq_lengths[b], b].clone()
            tensor[:seq_lengths[b] - washout[b], b] = tmp
            tensor[seq_lengths[b] - washout[b]:, b] = 0
            seq_lengths[b] -= washout[b]

            if bidirectional:
                # NOTE(review): uses the already-decremented length, matching
                # upstream pytorch-esn — confirm this is the intended trim.
                tensor[seq_lengths[b] - washout[b]:, b] = 0
                seq_lengths[b] -= washout[b]

    if type(seq_lengths) == list:
        max_len = max(seq_lengths)
    else:
        max_len = max(seq_lengths).item()

    return tensor[:max_len], seq_lengths


## esn
#from .reservoir import Reservoir
#from ..utils import washout_tensor

class _ESN(nn.Module):
    """ Applies an Echo State Network to an input sequence. Multi-layer Echo
    State Network is based on paper
    Deep Echo State Network (DeepESN): A Brief Survey - Gallicchio, Micheli 2017

    Args:
        input_size: The number of expected features in the input x.
        hidden_size: The number of features in the hidden state h.
        output_size: The number of expected features in the output y.
        num_layers: Number of recurrent layers. Default: 1
        nonlinearity: The non-linearity to use ['tanh'|'relu'|'id'].
            Default: 'tanh'
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        leaking_rate: Leaking rate of reservoir's neurons. Default: 1
        spectral_radius: Desired spectral radius of recurrent weight matrix.
            Default: 0.9
        w_ih_scale: Scale factor for first layer's input weights (w_ih_l0). It
            can be a number or a tensor of size '1 + input_size' and first element
            is the bias' scale factor. Default: 1
        lambda_reg: Ridge regression's shrinkage parameter. Default: 1
        density: Recurrent weight matrix's density. Default: 1
        w_io: If 'True', then the network uses trainable input-to-output
            connections. Default: ``False``
        readout_training: Readout's traning algorithm ['gd'|'svd'|'cholesky'|'inv'].
            If 'gd', gradients are accumulated during backward
            pass. If 'svd', 'cholesky' or 'inv', the network will learn readout's
            parameters during the forward pass using ridge regression. The
            coefficients are computed using SVD, Cholesky decomposition or
            standard ridge regression formula. 'gd', 'cholesky' and 'inv'
            permit the usage of mini-batches to train the readout.
            If 'inv' and matrix is singular, pseudoinverse is used.
        output_steps: defines how the reservoir's output will be used by ridge
            regression method ['all', 'mean', 'last'].
            If 'all', the entire reservoir output matrix will be used.
            If 'mean', the mean of reservoir output matrix along the timesteps
            dimension will be used.
            If 'last', only the last timestep of the reservoir output matrix
            will be used.
            'mean' and 'last' are useful for classification tasks.

    Inputs: input, washout, h_0, target
        input (seq_len, batch, input_size): tensor containing the features of
            the input sequence. The input can also be a packed variable length
            sequence. See `torch.nn.utils.rnn.pack_padded_sequence`
        washout (batch): number of initial timesteps during which output of the
            reservoir is not forwarded to the readout. One value per batch's
            sample.
        h_0 (num_layers, batch, hidden_size): tensor containing
             the initial reservoir's hidden state for each element in the batch.
             Defaults to zero if not provided.
        target (seq_len*batch - washout*batch, output_size): tensor containing
            the features of the batch's target sequences rolled out along one
            axis, minus the washouts and the padded values. It is only needed
            for readout's training in offline mode. Use `prepare_target` to
            compute it.

    Outputs: output, h_n
        - output (seq_len, batch, hidden_size): tensor containing the output
        features (h_k) from the readout, for each k.
        - **h_n** (num_layers * num_directions, batch, hidden_size): tensor
          containing the reservoir's hidden state for k=seq_len.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1,
                 nonlinearity='tanh', batch_first=False, leaking_rate=1,
                 spectral_radius=0.9, w_ih_scale=1, lambda_reg=0, density=1,
                 w_io=False, readout_training='svd', output_steps='all'):
        super(_ESN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        if nonlinearity == 'tanh':
            mode = 'RES_TANH'
        elif nonlinearity == 'relu':
            mode = 'RES_RELU'
        elif nonlinearity == 'id':
            mode = 'RES_ID'
        else:
            raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity))
        self.batch_first = batch_first
        self.leaking_rate = leaking_rate
        self.spectral_radius = spectral_radius
        # Broadcast a scalar scale to a per-input vector; element 0 scales
        # the bias, elements 1..input_size scale the input weights.
        if type(w_ih_scale) != torch.Tensor:
            self.w_ih_scale = torch.ones(input_size + 1)
            self.w_ih_scale *= w_ih_scale
        else:
            self.w_ih_scale = w_ih_scale

        self.lambda_reg = lambda_reg
        self.density = density
        self.w_io = w_io
        if readout_training in {'gd', 'svd', 'cholesky', 'inv'}:
            self.readout_training = readout_training
        else:
            raise ValueError("Unknown readout training algorithm '{}'".format(
                readout_training))

        # Reservoir is assumed to be defined elsewhere in this file
        # (fixed, untrained recurrent weights).
        self.reservoir = Reservoir(mode, input_size, hidden_size, num_layers,
                                   leaking_rate, spectral_radius,
                                   self.w_ih_scale, density,
                                   batch_first=batch_first)

        if w_io:
            self.readout = nn.Linear(input_size + hidden_size * num_layers,
                                     output_size)
        else:
            self.readout = nn.Linear(hidden_size * num_layers, output_size)
        # BUGFIX: the original compared against 'offline', which is not an
        # accepted readout_training value, so the readout was never frozen.
        # All non-'gd' modes fit the readout in closed form, so its weight
        # does not need gradient tracking.
        if readout_training != 'gd':
            self.readout.weight.requires_grad = False

        if output_steps in {'all', 'mean', 'last'}:
            self.output_steps = output_steps
        else:
            raise ValueError("Unknown task '{}'".format(
                output_steps))

        # Accumulators for mini-batch ridge regression ('cholesky'/'inv').
        self.XTX = None
        self.XTy = None
        self.X = None

    def forward(self, input, washout, h_0=None, target=None):
        with torch.no_grad():
            # PackedSequence / pad_packed_sequence are expected to be imported
            # elsewhere in this file (torch.nn.utils.rnn).
            is_packed = isinstance(input, PackedSequence)

            output, hidden = self.reservoir(input, h_0)
            if is_packed:
                output, seq_lengths = pad_packed_sequence(output,
                                                          batch_first=self.batch_first)
            else:
                if self.batch_first:
                    # One entry per sample, all equal to the sequence length.
                    seq_lengths = output.size(0) * [output.size(1)]
                else:
                    seq_lengths = output.size(1) * [output.size(0)]

            if self.batch_first:
                output = output.transpose(0, 1)

            # Discard the first washout[b] steps of every sequence.
            output, seq_lengths = washout_tensor(output, washout, seq_lengths)

            if self.w_io:
                if is_packed:
                    input, input_lengths = pad_packed_sequence(input,
                                                          batch_first=self.batch_first)

                if self.batch_first:
                    input = input.transpose(0, 1)

                if not is_packed:
                    # BUGFIX: compute lengths from the time-major layout; the
                    # original computed them before the transpose above, which
                    # produced wrong lengths when batch_first=True.
                    input_lengths = [input.size(0)] * input.size(1)

                input, _ = washout_tensor(input, washout, input_lengths)
                output = torch.cat([input, output], -1)

            if self.readout_training == 'gd' or target is None:
                # Online (gradient) readout, or plain inference.
                with torch.enable_grad():
                    output = self.readout(output)

                    if is_packed:
                        # Zero the padded tail of each sequence.
                        for i in range(output.size(1)):
                            if seq_lengths[i] < output.size(0):
                                output[seq_lengths[i]:, i] = 0

                    if self.batch_first:
                        output = output.transpose(0, 1)

                    return output, hidden

            else:
                # Offline readout training: build the design matrix X with a
                # leading column of ones (bias term) and accumulate/solve.
                batch_size = output.size(1)

                X = torch.ones(target.size(0), 1 + output.size(2), device=target.device)
                row = 0
                for s in range(batch_size):
                    if self.output_steps == 'all':
                        X[row:row + seq_lengths[s], 1:] = output[:seq_lengths[s],
                                                          s]
                        row += seq_lengths[s]
                    elif self.output_steps == 'mean':
                        X[row, 1:] = torch.mean(output[:seq_lengths[s], s], 0)
                        row += 1
                    elif self.output_steps == 'last':
                        X[row, 1:] = output[seq_lengths[s] - 1, s]
                        row += 1

                if self.readout_training == 'cholesky':
                    # Accumulate normal equations; solved later in fit().
                    if self.XTX is None:
                        self.XTX = torch.mm(X.t(), X)
                        self.XTy = torch.mm(X.t(), target)
                    else:
                        self.XTX += torch.mm(X.t(), X)
                        self.XTy += torch.mm(X.t(), target)

                elif self.readout_training == 'svd':
                    # Scikit-Learn-style SVD solver for ridge regression.
                    # torch.svd is deprecated; torch.linalg.svd returns Vh,
                    # so V == Vh.t().
                    U, s, Vh = torch.linalg.svd(X, full_matrices=False)
                    idx = s > 1e-15  # same default value as scipy.linalg.pinv
                    s_nnz = s[idx][:, None]
                    UTy = torch.mm(U.t(), target)
                    d = torch.zeros(s.size(0), 1, device=X.device)
                    d[idx] = s_nnz / (s_nnz ** 2 + self.lambda_reg)
                    d_UT_y = d * UTy
                    W = torch.mm(Vh.t(), d_UT_y).t()

                    self.readout.bias = nn.Parameter(W[:, 0])
                    self.readout.weight = nn.Parameter(W[:, 1:])
                elif self.readout_training == 'inv':
                    self.X = X
                    if self.XTX is None:
                        self.XTX = torch.mm(X.t(), X)
                        self.XTy = torch.mm(X.t(), target)
                    else:
                        self.XTX += torch.mm(X.t(), X)
                        self.XTy += torch.mm(X.t(), target)

                return None, None

    def fit(self):
        """Finalize offline readout training from the accumulated statistics.

        No-op for 'gd' (gradient-trained) and 'svd' (solved in forward).
        """
        if self.readout_training in {'gd', 'svd'}:
            return

        if self.readout_training == 'cholesky':
            # Solve (X^T X + lambda I) W^T = X^T y.
            # BUGFIX: torch.solve was removed from PyTorch; torch.linalg.solve
            # is the supported replacement (note the swapped argument order).
            A = self.XTX + self.lambda_reg * torch.eye(
                self.XTX.size(0), device=self.XTX.device)
            W = torch.linalg.solve(A, self.XTy).t()
            self.XTX = None
            self.XTy = None

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])
        elif self.readout_training == 'inv':
            I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
                self.XTX.device)
            A = self.XTX + I
            X_rank = torch.linalg.matrix_rank(A).item()

            # NOTE(review): rank of A is compared against X's row count, as in
            # the original; fall back to the pseudoinverse when deficient.
            if X_rank == self.X.size(0):
                W = torch.mm(torch.inverse(A), self.XTy).t()
            else:
                W = torch.mm(torch.pinverse(A), self.XTy).t()

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])

            self.XTX = None
            self.XTy = None

    def reset_parameters(self):
        """Re-initialize both the reservoir and the readout layer."""
        self.reservoir.reset_parameters()
        self.readout.reset_parameters()


class ESN(nn.Module):
    """Echo State Network forecaster.

    Runs the input sequence through a multi-layer reservoir (``_ESN``) and
    linearly maps the concatenated final hidden states to an
    ``(n_out, target_size)`` forecast.

    Based on ``torchesn/nn/echo_state_network.py`` from the pytorch-esn
    project.
    """
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, hidden_layer_size=5000, num_layers=1):
        super(ESN, self).__init__()
        self.out_len = n_out
        self.target_size = target_size
        # NOTE(review): washout is set to the hidden size rather than a
        # timestep count. forward() only uses the final hidden state, so the
        # (washed-out) readout output is discarded anyway — confirm intent.
        self.washout = [hidden_layer_size]

        self.esn = _ESN(input_size=input_size,
                        hidden_size=hidden_layer_size,
                        output_size=target_size,
                        num_layers=num_layers,
                        batch_first=True)
        # Readout from the last hidden state to the flattened forecast.
        self.fc = nn.Linear(hidden_layer_size, n_out * target_size)

    def forward(self, input_seq):
        n_samples = input_seq.shape[0]

        # One washout value per sample in the batch.
        _, final_hidden = self.esn(input_seq, self.washout * n_samples)

        # (num_layers, batch, hidden) -> (batch, num_layers*hidden) -> readout
        flat_hidden = final_hidden.permute(1, 0, 2).reshape(n_samples, -1)
        forecast = self.fc(flat_hidden)
        return forecast.view(n_samples, self.out_len, self.target_size)


##################
##### LSTNet #####
##################

class LSTNet(nn.Module):
    """LSTNet: CNN feature extraction + GRU + skip-GRU + autoregressive highway.

    Args:
        n_in: input window size (timesteps).
        n_out: forecast horizon (output timesteps).
        input_size: number of input series.
        target_size: number of output series.
        hidRNN / hidCNN / hidSkip: hidden sizes of the GRU / conv / skip-GRU.
        CNN_kernel: temporal kernel size of the conv layer.
        skip: skip length for the skip-recurrent component (0 disables it).
        highway_window: window of the linear autoregressive highway (0 disables).
        output_fun: 'sigmoid' | 'tanh' | anything else for identity.

    Input:  (batch, n_in, input_size)
    Output: (batch, n_out, target_size)
    """
    def __init__(self, n_in=8, n_out=1, input_size=3, target_size=3, hidRNN=100, hidCNN=100, hidSkip=5, CNN_kernel=6, skip=5, highway_window=5, output_fun='sigmoid'):
        super(LSTNet, self).__init__()
        self.P = n_in  # window size
        self.n_out = n_out
        self.m = input_size  # number of time series
        self.target_size = target_size
        self.hidR = hidRNN
        self.hidC = hidCNN
        self.hidS = hidSkip
        self.Ck = CNN_kernel
        self.skip = skip
        # Number of complete skip segments in the conv output.
        # BUGFIX: the original divided by skip unconditionally
        # (ZeroDivisionError for skip=0), and crashed in forward() whenever
        # pt == 0 (e.g. the class defaults n_in=8, CNN_kernel=6, skip=5).
        # When no full segment fits, disable the skip component entirely.
        self.pt = (self.P - self.Ck) // self.skip if self.skip > 0 else 0
        if self.pt <= 0:
            self.skip = 0
        self.hw = highway_window  # window size of the highway component
        self.conv1 = nn.Conv2d(1, self.hidC, kernel_size=(self.Ck, self.m))
        self.GRU1 = nn.GRU(self.hidC, self.hidR)
        self.dropout = nn.Dropout(p=0)  # dropout applied to layers (0 = none)
        if self.skip > 0:
            self.GRUskip = nn.GRU(self.hidC, self.hidS)
            self.linear1 = nn.Linear(self.hidR + self.skip * self.hidS, target_size * n_out)
        else:
            self.linear1 = nn.Linear(self.hidR, target_size * n_out)
        if self.hw > 0:
            self.highway = nn.Linear(self.hw, n_out)
        self.output = None
        if output_fun == 'sigmoid':
            self.output = torch.sigmoid
        elif output_fun == 'tanh':
            self.output = torch.tanh

    def forward(self, x):
        batch_size = x.size(0)

        # CNN over (batch, 1, time, series) -> (batch, hidC, time')
        c = x.view(-1, 1, self.P, self.m)
        c = F.relu(self.conv1(c))
        c = self.dropout(c)
        c = torch.squeeze(c, 3)

        # RNN over the conv features (time-major for nn.GRU).
        r = c.permute(2, 0, 1).contiguous()
        _, r = self.GRU1(r)
        r = self.dropout(torch.squeeze(r, 0))

        # Skip-RNN: fold time into (segment, skip-phase) and run a GRU per phase.
        if self.skip > 0:
            s = c[:, :, int(-self.pt * self.skip):].contiguous()
            s = s.view(batch_size, self.hidC, self.pt, self.skip)
            s = s.permute(2, 0, 3, 1).contiguous()
            s = s.view(self.pt, batch_size * self.skip, self.hidC)
            _, s = self.GRUskip(s)
            s = s.view(batch_size, self.skip * self.hidS)
            s = self.dropout(s)
            r = torch.cat((r, s), 1)

        res = self.linear1(r)
        res = res.view(batch_size, self.n_out, self.target_size)

        # Highway: linear autoregression on the last hw raw inputs.
        # NOTE(review): the final reshape assumes input_size == target_size.
        if self.hw > 0:
            z = x[:, -self.hw:, :]
            z = z.permute(0, 2, 1).contiguous().view(-1, self.hw)
            z = self.highway(z)
            z = z.view(-1, self.n_out, self.target_size)
            res = res + z

        if self.output:
            res = self.output(res)
        return res


if __name__ == '__main__':

    # Shared experiment dimensions.
    n_in = timesteps = 32            # window size / input sequence length
    n_out = prediction_horizon = 5   # forecast horizon / output length
    input_size = n_timeseries = 3    # number of parallel series (input dim)
    batch_size = bs = 16

    # One instance of every implemented forecaster, all on GPU.
    model_list = [
        MTGNN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, subgraph_size=n_timeseries, in_dim=1, node_dim=40, dilation_exponential=2, conv_channels=32, residual_channels=16, skip_channels=32, end_channels=64, layers=3, gcn_depth=5, propalpha=0.05, tanhalpha=3, gcn_true=True, buildA_true=True, predefined_A=None, static_feat=None, dropout=0.0, layer_norm_affline=True, device='cuda').cuda(),
        StemGNN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, stack_cnt=2, multi_layer=5, dropout_rate=0, leaky_rate=0.2, device='cuda'),
        Transformer(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0., activation='relu', custom_encoder=None, custom_decoder=None).cuda(),
        ConvTrans(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, feature_size=256, kernel_size=9, nhead=8, num_layers=1, dropout=0).cuda(),
        _TCN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, kernel_size=7, num_filters=3, num_layers=None, dilation_base=2, weight_norm=False, dropout=0.).cuda(),
        SCINet(n_in = timesteps, n_out = prediction_horizon, input_size = n_timeseries, hid_size = 1, num_stacks = 1, num_levels = 3, concat_len = 0, groups = 1, kernel = 3, dropout = 0, single_step_output_One = 0, positionalE = True, modified = True, RIN = True).cuda(),
        FNN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, activation='ReLU', fc_layers=[512, 256, 128, 64]).cuda(),
        NBEATS(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, generic_architecture=True, num_stacks=30, num_blocks=1, num_layers=4, layer_widths=256, expansion_coefficient_dim=5, trend_polynomial_degree=2).cuda(),
        NHITS(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, num_stacks=30, num_blocks=1, num_layers=4, n_mlp_units=[200,300,200,100]).cuda(),
        RNN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, name='GRU', hidden_layer_size=100, num_layers=1, bidirectional=False).cuda(),
        ESN(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries, hidden_layer_size=100, num_layers=1).cuda(),
        LSTNet(n_in=timesteps, n_out=prediction_horizon, input_size=n_timeseries, target_size=n_timeseries).cuda(),
    ]

    x = torch.randn(batch_size, n_in, input_size).cuda()

    # Smoke test: every model maps (batch, n_in, input_size)
    # to (batch, n_out, input_size).
    for model in model_list:
        y = model(x)
        print(y.shape)

# refer SCINet-main\models\SCINet.py